Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
 lib/Analysis/ScalarEvolution.cpp | 933 ++++++++++++++++++++++++++-----------
 1 file changed, 645 insertions(+), 288 deletions(-)
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index 9539fd7c7559..0b8604187121 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -59,12 +59,23 @@
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/FoldingSet.h"
+#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
@@ -72,28 +83,55 @@
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
-#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
+#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
-#include "llvm/Support/MathExtras.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
+#include <cassert>
+#include <climits>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <map>
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
+
using namespace llvm;
#define DEBUG_TYPE "scalar-evolution"
@@ -115,11 +153,11 @@ MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
cl::init(100));
// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
+static cl::opt<bool> VerifySCEV(
+ "verify-scev", cl::Hidden,
+ cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
static cl::opt<bool>
-VerifySCEV("verify-scev",
- cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
-static cl::opt<bool>
- VerifySCEVMap("verify-scev-maps",
+ VerifySCEVMap("verify-scev-maps", cl::Hidden,
cl::desc("Verify no dangling value in ScalarEvolution's "
"ExprValueMap (slow)"));
@@ -415,9 +453,6 @@ void SCEVUnknown::deleted() {
}
void SCEVUnknown::allUsesReplacedWith(Value *New) {
- // Clear this SCEVUnknown from various maps.
- SE->forgetMemoizedResults(this);
-
// Remove this SCEVUnknown from the uniquing map.
SE->UniqueSCEVs.RemoveNode(this);
@@ -514,10 +549,10 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
/// Since we do not continue running this routine on expression trees once we
/// have seen unequal values, there is no need to track them in the cache.
static int
-CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
+CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
const LoopInfo *const LI, Value *LV, Value *RV,
unsigned Depth) {
- if (Depth > MaxValueCompareDepth || EqCache.count({LV, RV}))
+ if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
return 0;
// Order pointer values after integer values. This helps SCEVExpander form
@@ -577,14 +612,14 @@ CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
for (unsigned Idx : seq(0u, LNumOps)) {
int Result =
- CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx),
+ CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
RInst->getOperand(Idx), Depth + 1);
if (Result != 0)
return Result;
}
}
- EqCache.insert({LV, RV});
+ EqCacheValue.unionSets(LV, RV);
return 0;
}
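
The cache change above swaps a SmallSet of ordered (LHS, RHS) pairs for llvm::EquivalenceClasses, a union-find structure, so equality learned for (A, B) and (B, C) also answers (A, C) without storing every pair. A minimal standalone sketch of the two calls the comparison code relies on:

#include "llvm/ADT/EquivalenceClasses.h"
#include <cassert>

void equivalenceClassesSketch() {
  llvm::EquivalenceClasses<int> EC;
  EC.unionSets(1, 2);            // cache "1 compares equal to 2"
  EC.unionSets(2, 3);            // cache "2 compares equal to 3"
  assert(EC.isEquivalent(1, 3)); // transitive hit; a pair set would miss it
}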
@@ -592,7 +627,8 @@ CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
// than RHS, respectively. A three-way result allows recursive comparisons to be
// more efficient.
static int CompareSCEVComplexity(
- SmallSet<std::pair<const SCEV *, const SCEV *>, 8> &EqCacheSCEV,
+ EquivalenceClasses<const SCEV *> &EqCacheSCEV,
+ EquivalenceClasses<const Value *> &EqCacheValue,
const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
DominatorTree &DT, unsigned Depth = 0) {
// Fast-path: SCEVs are uniqued so we can do a quick equality check.
@@ -604,7 +640,7 @@ static int CompareSCEVComplexity(
if (LType != RType)
return (int)LType - (int)RType;
- if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.count({LHS, RHS}))
+ if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.isEquivalent(LHS, RHS))
return 0;
// Aside from the getSCEVType() ordering, the particular ordering
// isn't very important except that it's beneficial to be consistent,
@@ -614,11 +650,10 @@ static int CompareSCEVComplexity(
const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
- SmallSet<std::pair<Value *, Value *>, 8> EqCache;
- int X = CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue(),
- Depth + 1);
+ int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
+ RU->getValue(), Depth + 1);
if (X == 0)
- EqCacheSCEV.insert({LHS, RHS});
+ EqCacheSCEV.unionSets(LHS, RHS);
return X;
}
@@ -659,14 +694,19 @@ static int CompareSCEVComplexity(
if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
+ // Compare NoWrap flags.
+ if (LA->getNoWrapFlags() != RA->getNoWrapFlags())
+ return (int)LA->getNoWrapFlags() - (int)RA->getNoWrapFlags();
+
// Lexicographically compare.
for (unsigned i = 0; i != LNumOps; ++i) {
- int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i),
- RA->getOperand(i), DT, Depth + 1);
+ int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+ LA->getOperand(i), RA->getOperand(i), DT,
+ Depth + 1);
if (X != 0)
return X;
}
- EqCacheSCEV.insert({LHS, RHS});
+ EqCacheSCEV.unionSets(LHS, RHS);
return 0;
}
@@ -682,15 +722,18 @@ static int CompareSCEVComplexity(
if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
+ // Compare NoWrap flags.
+ if (LC->getNoWrapFlags() != RC->getNoWrapFlags())
+ return (int)LC->getNoWrapFlags() - (int)RC->getNoWrapFlags();
+
for (unsigned i = 0; i != LNumOps; ++i) {
- if (i >= RNumOps)
- return 1;
- int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i),
- RC->getOperand(i), DT, Depth + 1);
+ int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+ LC->getOperand(i), RC->getOperand(i), DT,
+ Depth + 1);
if (X != 0)
return X;
}
- EqCacheSCEV.insert({LHS, RHS});
+ EqCacheSCEV.unionSets(LHS, RHS);
return 0;
}
@@ -699,14 +742,14 @@ static int CompareSCEVComplexity(
const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
// Lexicographically compare udiv expressions.
- int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(),
- DT, Depth + 1);
+ int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
+ RC->getLHS(), DT, Depth + 1);
if (X != 0)
return X;
- X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(), DT,
- Depth + 1);
+ X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
+ RC->getRHS(), DT, Depth + 1);
if (X == 0)
- EqCacheSCEV.insert({LHS, RHS});
+ EqCacheSCEV.unionSets(LHS, RHS);
return X;
}
@@ -717,10 +760,11 @@ static int CompareSCEVComplexity(
const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
// Compare cast expressions by operand.
- int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(),
- RC->getOperand(), DT, Depth + 1);
+ int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+ LC->getOperand(), RC->getOperand(), DT,
+ Depth + 1);
if (X == 0)
- EqCacheSCEV.insert({LHS, RHS});
+ EqCacheSCEV.unionSets(LHS, RHS);
return X;
}
@@ -739,26 +783,26 @@ static int CompareSCEVComplexity(
/// results from this routine. In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.
-///
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
LoopInfo *LI, DominatorTree &DT) {
if (Ops.size() < 2) return; // Noop
- SmallSet<std::pair<const SCEV *, const SCEV *>, 8> EqCache;
+ EquivalenceClasses<const SCEV *> EqCacheSCEV;
+ EquivalenceClasses<const Value *> EqCacheValue;
if (Ops.size() == 2) {
// This is the common case, which also happens to be trivially simple.
// Special case it.
const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
- if (CompareSCEVComplexity(EqCache, LI, RHS, LHS, DT) < 0)
+ if (CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, RHS, LHS, DT) < 0)
std::swap(LHS, RHS);
return;
}
// Do the rough sort by complexity.
std::stable_sort(Ops.begin(), Ops.end(),
- [&EqCache, LI, &DT](const SCEV *LHS, const SCEV *RHS) {
- return
- CompareSCEVComplexity(EqCache, LI, LHS, RHS, DT) < 0;
+ [&](const SCEV *LHS, const SCEV *RHS) {
+ return CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+ LHS, RHS, DT) < 0;
});
// Now that we are sorted by complexity, group elements of the same
@@ -785,14 +829,16 @@ static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
// Returns the size of the SCEV S.
static inline int sizeOfSCEV(const SCEV *S) {
struct FindSCEVSize {
- int Size;
- FindSCEVSize() : Size(0) {}
+ int Size = 0;
+
+ FindSCEVSize() = default;
bool follow(const SCEV *S) {
++Size;
// Keep looking at all operands of S.
return true;
}
+
bool isDone() const {
return false;
}
@@ -1032,7 +1078,7 @@ private:
const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
};
-}
+} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Simple SCEV method implementations
@@ -1157,7 +1203,6 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
-///
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
ScalarEvolution &SE) const {
const SCEV *Result = getStart();
@@ -1256,6 +1301,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
return S;
}
@@ -1343,7 +1389,8 @@ struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
-}
+
+} // end anonymous namespace
// The recurrence AR has been shown to have no signed/unsigned wrap or something
// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
@@ -1473,7 +1520,6 @@ static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
//
// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
// is `Delta` (defined below).
-//
template <typename ExtendOpTy>
bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
const SCEV *Step,
@@ -1484,7 +1530,6 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
// time here. It is correct (but more expensive) to continue with a
// non-constant `Start` and do a general SCEV subtraction to compute
// `PreStart` below.
- //
const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
if (!StartC)
return false;
@@ -1547,6 +1592,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
return S;
}
@@ -1733,6 +1779,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
return S;
}
@@ -1770,6 +1817,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
return S;
}
@@ -1981,12 +2029,12 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
return S;
}
/// getAnyExtendExpr - Return a SCEV for the given operand extended with
/// unspecified bits out to the given type.
-///
const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
Type *Ty) {
assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
@@ -2057,7 +2105,6 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
/// may be exposed. This helps getAddRecExpr short-circuit extra work in
/// the common case where no interesting opportunities are present, and
/// is also used as a check to avoid infinite recursion.
-///
static bool
CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
SmallVectorImpl<const SCEV *> &NewOps,
@@ -2132,7 +2179,8 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
const SmallVectorImpl<const SCEV *> &Ops,
SCEV::NoWrapFlags Flags) {
using namespace std::placeholders;
- typedef OverflowingBinaryOperator OBO;
+
+ using OBO = OverflowingBinaryOperator;
bool CanAnalyze =
Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
@@ -2306,12 +2354,23 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
// Check for truncates. If all the operands are truncated from the same
// type, see if factoring out the truncate would permit the result to be
- // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
+ // folded. e.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
// if the contents of the resulting outer trunc fold to something simple.
- for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
- const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
- Type *DstType = Trunc->getType();
- Type *SrcType = Trunc->getOperand()->getType();
+ auto FindTruncSrcType = [&]() -> Type * {
+ // We're ultimately looking to fold an addrec of truncs and muls of only
+ // constants and truncs, so if we find any other types of SCEV
+ // as operands of the addrec then we bail and return nullptr here.
+ // Otherwise, we return the type of the operand of a trunc that we find.
+ if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
+ return T->getOperand()->getType();
+ if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
+ const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
+ if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
+ return T->getOperand()->getType();
+ }
+ return nullptr;
+ };
+ if (auto *SrcType = FindTruncSrcType()) {
SmallVector<const SCEV *, 8> LargeOps;
bool Ok = true;
// Check all the operands to see if they can be represented in the
@@ -2354,7 +2413,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1);
// If it folds to something simple, use it. Otherwise, don't.
if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
- return getTruncateExpr(Fold, DstType);
+ return getTruncateExpr(Fold, Ty);
}
}
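
A concrete instance of the factoring above, with illustrative values not taken from the patch: for i64 values %x and %y,

  (5 * (trunc i64 %x to i32)) + (trunc i64 %y to i32)
    --> (trunc i64 (5 * %x + %y) to i32)

This is valid because truncation commutes with add and mul, and the rewrite is only kept when the inner wide add folds to something simple, per the isa<SCEVConstant> / isa<SCEVUnknown> check on Fold.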
@@ -2608,8 +2667,8 @@ ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops,
SCEV::NoWrapFlags Flags) {
FoldingSetNodeID ID;
ID.AddInteger(scAddExpr);
- for (unsigned i = 0, e = Ops.size(); i != e; ++i)
- ID.AddPointer(Ops[i]);
+ for (const SCEV *Op : Ops)
+ ID.AddPointer(Op);
void *IP = nullptr;
SCEVAddExpr *S =
static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
@@ -2619,6 +2678,7 @@ ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops,
S = new (SCEVAllocator)
SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
}
S->setNoWrapFlags(Flags);
return S;
@@ -2640,6 +2700,7 @@ ScalarEvolution::getOrCreateMulExpr(SmallVectorImpl<const SCEV *> &Ops,
S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
}
S->setNoWrapFlags(Flags);
return S;
@@ -2679,20 +2740,24 @@ static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
/// Determine if any of the operands in this SCEV are a constant or if
/// any of the add or multiply expressions in this SCEV contain a constant.
-static bool containsConstantSomewhere(const SCEV *StartExpr) {
- SmallVector<const SCEV *, 4> Ops;
- Ops.push_back(StartExpr);
- while (!Ops.empty()) {
- const SCEV *CurrentExpr = Ops.pop_back_val();
- if (isa<SCEVConstant>(*CurrentExpr))
- return true;
+static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
+ struct FindConstantInAddMulChain {
+ bool FoundConstant = false;
- if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) {
- const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr);
- Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end());
+ bool follow(const SCEV *S) {
+ FoundConstant |= isa<SCEVConstant>(S);
+ return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
}
- }
- return false;
+
+ bool isDone() const {
+ return FoundConstant;
+ }
+ };
+
+ FindConstantInAddMulChain F;
+ SCEVTraversal<FindConstantInAddMulChain> ST(F);
+ ST.visitAll(StartExpr);
+ return F.FoundConstant;
}
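
The visitor above follows the SCEVTraversal contract: follow(S) is invoked for each reachable expression and returns whether to descend into S's operands, while isDone() lets the walk stop early. A hypothetical visitor in the same style, counting add expressions:

struct CountAdds {
  unsigned NumAdds = 0;

  bool follow(const SCEV *S) {
    NumAdds += isa<SCEVAddExpr>(S);
    return true;                        // keep descending into operands
  }

  bool isDone() const { return false; } // never stop early
};

unsigned countAdds(const SCEV *Expr) {
  CountAdds C;
  SCEVTraversal<CountAdds> ST(C);
  ST.visitAll(Expr);
  return C.NumAdds;
}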
/// Get a canonical multiply expression, or something simpler if possible.
@@ -2729,7 +2794,11 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
// If any of Add's ops are Adds or Muls with a constant,
// apply this transformation as well.
if (Add->getNumOperands() == 2)
- if (containsConstantSomewhere(Add))
+ // TODO: There are some cases where this transformation is not
+ // profitable, for example:
+ // Add = (C0 + X) * Y + Z.
+ // Maybe the scope of this transformation should be narrowed down.
+ if (containsConstantInAddMulChain(Add))
return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
SCEV::FlagAnyWrap, Depth + 1),
getMulExpr(LHSC, Add->getOperand(1),
@@ -2941,6 +3010,34 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
return getOrCreateMulExpr(Ops, Flags);
}
+/// Represents an unsigned remainder expression based on unsigned division.
+const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
+ const SCEV *RHS) {
+ assert(getEffectiveSCEVType(LHS->getType()) ==
+ getEffectiveSCEVType(RHS->getType()) &&
+ "SCEVURemExpr operand types don't match!");
+
+ // Short-circuit easy cases
+ if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
+ // If constant is one, the result is trivial
+ if (RHSC->getValue()->isOne())
+ return getZero(LHS->getType()); // X urem 1 --> 0
+
+ // If constant is a power of two, fold into a zext(trunc(LHS)).
+ if (RHSC->getAPInt().isPowerOf2()) {
+ Type *FullTy = LHS->getType();
+ Type *TruncTy =
+ IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
+ return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
+ }
+ }
+
+ // Fall back to the general identity: %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
+ const SCEV *UDiv = getUDivExpr(LHS, RHS);
+ const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
+ return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
+}
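
The three folds in getURemExpr mirror ordinary unsigned arithmetic; a standalone check of the same identities on plain integers (not SCEV):

#include <cassert>
#include <cstdint>

int main() {
  uint32_t X = 1234567, Y = 89;
  assert(X % 1 == 0);               // X urem 1 --> 0
  assert(X % 8 == (X & 7));         // power of two keeps the low log2(8) = 3
                                    // bits: zext(trunc(X)) in SCEV terms
  assert(X % Y == X - (X / Y) * Y); // general case: X - (X udiv Y) * Y
  return 0;
}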
+
/// Get a canonical unsigned division expression, or something simpler if
/// possible.
const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
@@ -3056,6 +3153,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
LHS, RHS);
UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
return S;
}
@@ -3236,6 +3334,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
O, Operands.size(), L);
UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
}
S->setNoWrapFlags(Flags);
return S;
@@ -3391,6 +3490,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator),
O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
return S;
}
@@ -3492,6 +3592,7 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator),
O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
return S;
}
@@ -3714,7 +3815,6 @@ const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
}
/// Return a SCEV corresponding to -V = -1*V
-///
const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
SCEV::NoWrapFlags Flags) {
if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
@@ -3957,6 +4057,7 @@ void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {
}
namespace {
+
class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
public:
static const SCEV *rewrite(const SCEV *S, const Loop *L,
@@ -3966,9 +4067,6 @@ public:
return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
}
- SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
- : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
-
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
if (!SE.isLoopInvariant(Expr, L))
Valid = false;
@@ -3986,10 +4084,93 @@ public:
bool isValid() { return Valid; }
private:
+ explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
+ : SCEVRewriteVisitor(SE), L(L) {}
+
+ const Loop *L;
+ bool Valid = true;
+};
+
+/// This class evaluates the compare condition by matching it against the
+/// condition of the loop latch. If there is a match, we assume a true value
+/// for the condition while building SCEV nodes.
+class SCEVBackedgeConditionFolder
+ : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
+public:
+ static const SCEV *rewrite(const SCEV *S, const Loop *L,
+ ScalarEvolution &SE) {
+ bool IsPosBECond = false;
+ Value *BECond = nullptr;
+ if (BasicBlock *Latch = L->getLoopLatch()) {
+ BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
+ if (BI && BI->isConditional()) {
+ assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
+ "Both outgoing branches should not target same header!");
+ BECond = BI->getCondition();
+ IsPosBECond = BI->getSuccessor(0) == L->getHeader();
+ } else {
+ return S;
+ }
+ }
+ SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
+ return Rewriter.visit(S);
+ }
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ const SCEV *Result = Expr;
+ bool InvariantF = SE.isLoopInvariant(Expr, L);
+
+ if (!InvariantF) {
+ Instruction *I = cast<Instruction>(Expr->getValue());
+ switch (I->getOpcode()) {
+ case Instruction::Select: {
+ SelectInst *SI = cast<SelectInst>(I);
+ Optional<const SCEV *> Res =
+ compareWithBackedgeCondition(SI->getCondition());
+ if (Res.hasValue()) {
+ bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne();
+ Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
+ }
+ break;
+ }
+ default: {
+ Optional<const SCEV *> Res = compareWithBackedgeCondition(I);
+ if (Res.hasValue())
+ Result = Res.getValue();
+ break;
+ }
+ }
+ }
+ return Result;
+ }
+
+private:
+ explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
+ bool IsPosBECond, ScalarEvolution &SE)
+ : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
+ IsPositiveBECond(IsPosBECond) {}
+
+ Optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
+
const Loop *L;
- bool Valid;
+  /// Condition of the loop's backedge branch.
+  Value *BackedgeCond = nullptr;
+  /// Set to true if the backedge is taken when the condition is true.
+  bool IsPositiveBECond;
};
+Optional<const SCEV *>
+SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
+
+  // If the value matches the backedge condition of the loop latch,
+  // return a constant evolution node based on whether the backedge
+  // branch is taken.
+ if (BackedgeCond == IC)
+ return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
+ : SE.getZero(Type::getInt1Ty(SE.getContext()));
+ return None;
+}
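
At the source level, the pattern this folder targets looks roughly like the following (hypothetical n, a, b, sum):

for (int i = 0; i != n; ++i) {
  int v = (i != n) ? a : b; // select condition matches the latch condition,
                            // so on iterations that take the backedge the
                            // folder can evaluate v as a
  sum += v;
}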
+
class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
public:
static const SCEV *rewrite(const SCEV *S, const Loop *L,
@@ -3999,9 +4180,6 @@ public:
return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
}
- SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
- : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
-
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
// Only allow AddRecExprs for this loop.
if (!SE.isLoopInvariant(Expr, L))
@@ -4015,12 +4193,17 @@ public:
Valid = false;
return Expr;
}
+
bool isValid() { return Valid; }
private:
+ explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
+ : SCEVRewriteVisitor(SE), L(L) {}
+
const Loop *L;
- bool Valid;
+ bool Valid = true;
};
+
} // end anonymous namespace
SCEV::NoWrapFlags
@@ -4028,7 +4211,8 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
if (!AR->isAffine())
return SCEV::FlagAnyWrap;
- typedef OverflowingBinaryOperator OBO;
+ using OBO = OverflowingBinaryOperator;
+
SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
if (!AR->hasNoSignedWrap()) {
@@ -4055,6 +4239,7 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
}
namespace {
+
/// Represents an abstract binary operation. This may exist as a
/// normal instruction or constant expression, or may have been
/// derived from an expression tree.
@@ -4062,16 +4247,16 @@ struct BinaryOp {
unsigned Opcode;
Value *LHS;
Value *RHS;
- bool IsNSW;
- bool IsNUW;
+ bool IsNSW = false;
+ bool IsNUW = false;
/// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
/// constant expression.
- Operator *Op;
+ Operator *Op = nullptr;
explicit BinaryOp(Operator *Op)
: Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
- IsNSW(false), IsNUW(false), Op(Op) {
+ Op(Op) {
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
IsNSW = OBO->hasNoSignedWrap();
IsNUW = OBO->hasNoUnsignedWrap();
@@ -4080,11 +4265,10 @@ struct BinaryOp {
explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
bool IsNUW = false)
- : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW),
- Op(nullptr) {}
+ : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
};
-}
+} // end anonymous namespace
/// Try to map \p V into a BinaryOp, and return \c None on failure.
static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
@@ -4101,6 +4285,7 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
case Instruction::Sub:
case Instruction::Mul:
case Instruction::UDiv:
+ case Instruction::URem:
case Instruction::And:
case Instruction::Or:
case Instruction::AShr:
@@ -4145,7 +4330,7 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
if (auto *F = CI->getCalledFunction())
switch (F->getIntrinsicID()) {
case Intrinsic::sadd_with_overflow:
- case Intrinsic::uadd_with_overflow: {
+ case Intrinsic::uadd_with_overflow:
if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT))
return BinaryOp(Instruction::Add, CI->getArgOperand(0),
CI->getArgOperand(1));
@@ -4161,13 +4346,21 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
return BinaryOp(Instruction::Add, CI->getArgOperand(0),
CI->getArgOperand(1), /* IsNSW = */ false,
/* IsNUW*/ true);
- }
-
case Intrinsic::ssub_with_overflow:
case Intrinsic::usub_with_overflow:
- return BinaryOp(Instruction::Sub, CI->getArgOperand(0),
- CI->getArgOperand(1));
+ if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT))
+ return BinaryOp(Instruction::Sub, CI->getArgOperand(0),
+ CI->getArgOperand(1));
+ // The same reasoning as sadd/uadd above.
+ if (F->getIntrinsicID() == Intrinsic::ssub_with_overflow)
+ return BinaryOp(Instruction::Sub, CI->getArgOperand(0),
+ CI->getArgOperand(1), /* IsNSW = */ true,
+ /* IsNUW = */ false);
+ else
+ return BinaryOp(Instruction::Sub, CI->getArgOperand(0),
+ CI->getArgOperand(1), /* IsNSW = */ false,
+ /* IsNUW = */ true);
case Intrinsic::smul_with_overflow:
case Intrinsic::umul_with_overflow:
return BinaryOp(Instruction::Mul, CI->getArgOperand(0),
@@ -4184,28 +4377,27 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
return None;
}
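
The new ssub/usub handling parallels the uadd/sadd case above: when the branch on the intrinsic's overflow bit proves the subtract never wraps, the extracted BinaryOp carries NSW (signed) or NUW (unsigned). A sketch of the guarded source pattern, using a builtin that Clang lowers to ssub.with.overflow:

#include <cstdlib>

int guardedSub(int A, int B) {
  int Res;
  if (__builtin_sub_overflow(A, B, &Res)) // llvm.ssub.with.overflow
    abort();                              // overflow edge never reaches the use
  return Res;                             // here Res == A - B, effectively nsw
}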
-/// Helper function to createAddRecFromPHIWithCasts. We have a phi
+/// Helper function to createAddRecFromPHIWithCasts. We have a phi
/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
-/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
-/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
+/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
+/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
/// follows one of the following patterns:
/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
/// If the SCEV expression of \p Op conforms with one of the expected patterns
/// we return the type of the truncation operation, and indicate whether the
-/// truncated type should be treated as signed/unsigned by setting
+/// truncated type should be treated as signed/unsigned by setting
/// \p Signed to true/false, respectively.
static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
bool &Signed, ScalarEvolution &SE) {
-
- // The case where Op == SymbolicPHI (that is, with no type conversions on
- // the way) is handled by the regular add recurrence creating logic and
+ // The case where Op == SymbolicPHI (that is, with no type conversions on
+ // the way) is handled by the regular add recurrence creating logic and
// would have already been triggered in createAddRecForPHI. Reaching it here
- // means that createAddRecFromPHI had failed for this PHI before (e.g.,
+ // means that createAddRecFromPHI had failed for this PHI before (e.g.,
// because one of the other operands of the SCEVAddExpr updating this PHI is
- // not invariant).
+ // not invariant).
//
- // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
+ // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
// this case predicates that allow us to prove that Op == SymbolicPHI will
// be added.
if (Op == SymbolicPHI)
@@ -4228,7 +4420,7 @@ static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
const SCEV *X = Trunc->getOperand();
if (X != SymbolicPHI)
return nullptr;
- Signed = SExt ? true : false;
+ Signed = SExt != nullptr;
return Trunc->getType();
}
@@ -4257,7 +4449,7 @@ static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
// and call this function with %SymbolicPHI = %X.
//
-// The analysis will find that the value coming around the backedge has
+// The analysis will find that the value coming around the backedge has
// the following SCEV:
// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
// Upon concluding that this matches the desired pattern, the function
@@ -4270,21 +4462,21 @@ static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
// under the predicates {P1,P2,P3}.
// This predicated rewrite will be cached in PredicatedSCEVRewrites:
-// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
+// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
//
// TODO's:
//
// 1) Extend the Induction descriptor to also support inductions that involve
-// casts: When needed (namely, when we are called in the context of the
-// vectorizer induction analysis), a Set of cast instructions will be
+// casts: When needed (namely, when we are called in the context of the
+// vectorizer induction analysis), a Set of cast instructions will be
// populated by this method, and provided back to isInductionPHI. This is
// needed to allow the vectorizer to properly record them to be ignored by
// the cost model and to avoid vectorizing them (otherwise these casts,
-// which are redundant under the runtime overflow checks, will be
-// vectorized, which can be costly).
+// which are redundant under the runtime overflow checks, will be
+// vectorized, which can be costly).
//
// 2) Support additional induction/PHISCEV patterns: We also want to support
-// inductions where the sext-trunc / zext-trunc operations (partly) occur
+// inductions where the sext-trunc / zext-trunc operations (partly) occur
// after the induction update operation (the induction increment):
//
// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
@@ -4294,17 +4486,16 @@ static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
// which correspond to a phi->trunc->add->sext/zext->phi update chain.
//
// 3) Outline common code with createAddRecFromPHI to avoid duplication.
-//
Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
SmallVector<const SCEVPredicate *, 3> Predicates;
- // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
+ // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
// return an AddRec expression under some predicate.
-
+
auto *PN = cast<PHINode>(SymbolicPHI->getValue());
const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
- assert (L && "Expecting an integer loop header phi");
+ assert(L && "Expecting an integer loop header phi");
// The loop may have multiple entrances or multiple exits; we can analyze
// this phi as an addrec if it has a unique entry value and a unique
@@ -4339,12 +4530,12 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
return None;
// If there is a single occurrence of the symbolic value, possibly
- // casted, replace it with a recurrence.
+ // casted, replace it with a recurrence.
unsigned FoundIndex = Add->getNumOperands();
Type *TruncTy = nullptr;
bool Signed;
for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
- if ((TruncTy =
+ if ((TruncTy =
isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
if (FoundIndex == e) {
FoundIndex = i;
@@ -4366,77 +4557,122 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
if (!isLoopInvariant(Accum, L))
return None;
-
- // *** Part2: Create the predicates
+ // *** Part2: Create the predicates
// Analysis was successful: we have a phi-with-cast pattern for which we
// can return an AddRec expression under the following predicates:
//
// P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
// fits within the truncated type (does not overflow) for i = 0 to n-1.
- // P2: An Equal predicate that guarantees that
+ // P2: An Equal predicate that guarantees that
// Start = (Ext ix (Trunc iy (Start) to ix) to iy)
- // P3: An Equal predicate that guarantees that
+ // P3: An Equal predicate that guarantees that
// Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
//
- // As we next prove, the above predicates guarantee that:
+ // As we next prove, the above predicates guarantee that:
// Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
//
//
// More formally, we want to prove that:
- // Expr(i+1) = Start + (i+1) * Accum
- // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
+ // Expr(i+1) = Start + (i+1) * Accum
+ // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
//
// Given that:
- // 1) Expr(0) = Start
- // 2) Expr(1) = Start + Accum
+ // 1) Expr(0) = Start
+ // 2) Expr(1) = Start + Accum
// = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
// 3) Induction hypothesis (step i):
- // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
+ // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
//
// Proof:
// Expr(i+1) =
// = Start + (i+1)*Accum
// = (Start + i*Accum) + Accum
- // = Expr(i) + Accum
- // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
+ // = Expr(i) + Accum
+ // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
// :: from step i
//
- // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
+ // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
//
// = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
// + (Ext ix (Trunc iy (Accum) to ix) to iy)
// + Accum :: from P3
//
- // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
+ // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
// + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
//
// = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
- // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
+ // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
//
// By induction, the same applies to all iterations 1<=i<n:
//
-
+
// Create a truncated addrec for which we will add a no overflow check (P1).
const SCEV *StartVal = getSCEV(StartValueV);
- const SCEV *PHISCEV =
+ const SCEV *PHISCEV =
getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
- getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
- const auto *AR = cast<SCEVAddRecExpr>(PHISCEV);
+ getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
- SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
- Signed ? SCEVWrapPredicate::IncrementNSSW
- : SCEVWrapPredicate::IncrementNUSW;
- const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
- Predicates.push_back(AddRecPred);
+ // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
+ // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
+ // will be constant.
+ //
+ // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
+ // add P1.
+ if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
+ SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
+ Signed ? SCEVWrapPredicate::IncrementNSSW
+ : SCEVWrapPredicate::IncrementNUSW;
+ const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
+ Predicates.push_back(AddRecPred);
+ }
// Create the Equal Predicates P2,P3:
- auto AppendPredicate = [&](const SCEV *Expr) -> void {
- assert (isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
+
+ // It is possible that the predicates P2 and/or P3 are computable at
+ // compile time due to StartVal and/or Accum being constants.
+ // If either one is, then we can check that now and escape if either P2
+ // or P3 is false.
+
+ // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
+ // for each of StartVal and Accum
+ auto getExtendedExpr = [&](const SCEV *Expr,
+ bool CreateSignExtend) -> const SCEV * {
+ assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
const SCEV *ExtendedExpr =
- Signed ? getSignExtendExpr(TruncatedExpr, Expr->getType())
- : getZeroExtendExpr(TruncatedExpr, Expr->getType());
+ CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
+ : getZeroExtendExpr(TruncatedExpr, Expr->getType());
+ return ExtendedExpr;
+ };
+
+ // Given:
+ // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
+ // = getExtendedExpr(Expr)
+ // Determine whether the predicate P: Expr == ExtendedExpr
+ // is known to be false at compile time
+ auto PredIsKnownFalse = [&](const SCEV *Expr,
+ const SCEV *ExtendedExpr) -> bool {
+ return Expr != ExtendedExpr &&
+ isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
+ };
+
+ const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
+ if (PredIsKnownFalse(StartVal, StartExtended)) {
+ DEBUG(dbgs() << "P2 is compile-time false\n";);
+ return None;
+ }
+
+ // The Step is always Signed (because the overflow checks are either
+ // NSSW or NUSW)
+ const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
+ if (PredIsKnownFalse(Accum, AccumExtended)) {
+ DEBUG(dbgs() << "P3 is compile-time false\n";);
+ return None;
+ }
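
A concrete compile-time-false P2, with illustrative numbers: Start = 257 and TruncTy = i8.

#include <cassert>
#include <cstdint>

int main() {
  int32_t Start = 257;
  uint8_t Trunc = static_cast<uint8_t>(Start); // trunc i32 257 to i8 -> 1
  int32_t Ext = Trunc;                         // zext i8 1 to i32    -> 1
  assert(Ext != Start); // Start == ext(trunc(Start)) is provably false, so
                        // the analysis bails instead of emitting a runtime
                        // predicate that can never hold
  return 0;
}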
+
+ auto AppendPredicate = [&](const SCEV *Expr,
+ const SCEV *ExtendedExpr) -> void {
if (Expr != ExtendedExpr &&
!isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
@@ -4444,14 +4680,14 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
Predicates.push_back(Pred);
}
};
-
- AppendPredicate(StartVal);
- AppendPredicate(Accum);
-
+
+ AppendPredicate(StartVal, StartExtended);
+ AppendPredicate(Accum, AccumExtended);
+
// *** Part3: Predicates are ready. Now go ahead and create the new addrec in
// which the casts had been folded away. The caller can rewrite SymbolicPHI
// into NewAR if it will also add the runtime overflow checks specified in
- // Predicates.
+ // Predicates.
auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
@@ -4463,7 +4699,6 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
-
auto *PN = cast<PHINode>(SymbolicPHI->getValue());
const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
if (!L)
@@ -4475,7 +4710,7 @@ ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
I->second;
// Analysis was done before and failed to create an AddRec:
- if (Rewrite.first == SymbolicPHI)
+ if (Rewrite.first == SymbolicPHI)
return None;
// Analysis was done before and succeeded to create an AddRec under
// a predicate:
@@ -4497,6 +4732,30 @@ ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
return Rewrite;
}
+// FIXME: This utility is currently required because the Rewriter currently
+// does not rewrite this expression:
+// {0, +, (sext ix (trunc iy to ix) to iy)}
+// into {0, +, %step},
+// even when the following Equal predicate exists:
+// "%step == (sext ix (trunc iy to ix) to iy)".
+bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
+ const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
+ if (AR1 == AR2)
+ return true;
+
+ auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
+ if (Expr1 != Expr2 && !Preds.implies(SE.getEqualPredicate(Expr1, Expr2)) &&
+ !Preds.implies(SE.getEqualPredicate(Expr2, Expr1)))
+ return false;
+ return true;
+ };
+
+ if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
+ !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
+ return false;
+ return true;
+}
+
/// A helper function for createAddRecFromPHI to handle simple cases.
///
/// This function tries to find an AddRec expression for the simplest (yet most
@@ -4612,7 +4871,8 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
SmallVector<const SCEV *, 8> Ops;
for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
if (i != FoundIndex)
- Ops.push_back(Add->getOperand(i));
+ Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
+ L, *this));
const SCEV *Accum = getAddExpr(Ops);
// This is not a valid addrec if the step amount is varying each
@@ -5599,7 +5859,7 @@ bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
ScalarEvolution::LoopProperties
ScalarEvolution::getLoopProperties(const Loop *L) {
- typedef ScalarEvolution::LoopProperties LoopProperties;
+ using LoopProperties = ScalarEvolution::LoopProperties;
auto Itr = LoopPropertiesCache.find(L);
if (Itr == LoopPropertiesCache.end()) {
@@ -5735,6 +5995,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
}
case Instruction::UDiv:
return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
+ case Instruction::URem:
+ return getURemExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
case Instruction::Sub: {
SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
if (BO->Op)
@@ -5886,7 +6148,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
}
break;
- case Instruction::AShr:
+ case Instruction::AShr: {
// AShr X, C, where C is a constant.
ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
if (!CI)
@@ -5938,6 +6200,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
}
break;
}
+ }
}
switch (U->getOpcode()) {
@@ -5948,6 +6211,21 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
case Instruction::SExt:
+ if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) {
+ // The NSW flag of a subtract does not always survive the conversion to
+ // A + (-1)*B. By pushing sign extension onto its operands we are much
+ // more likely to preserve NSW and allow later AddRec optimisations.
+ //
+ // NOTE: This is effectively duplicating this logic from getSignExtend:
+ // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
+ // but by that point the NSW information has potentially been lost.
+ if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
+ Type *Ty = U->getType();
+ auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
+ auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
+ return getMinusSCEV(V1, V2, SCEV::FlagNSW);
+ }
+ }
return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
case Instruction::BitCast:
@@ -5987,8 +6265,6 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
return getUnknown(V);
}
-
-
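
The SExt handling added above keeps NSW alive across the A + (-1)*B rewrite; numerically, with i8 -> i32 standing in for the wider types (hypothetical values):

#include <cassert>
#include <cstdint>

int main() {
  int8_t A = 100, B = -20;
  int8_t S = A - B; // 120 fits in i8, i.e. the IR subtract is 'sub nsw'
  // With nsw, sign-extending the result equals subtracting the
  // sign-extended operands, the form createSCEV now produces directly:
  assert(static_cast<int32_t>(S) ==
         static_cast<int32_t>(A) - static_cast<int32_t>(B));
  return 0;
}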
//===----------------------------------------------------------------------===//
// Iteration Count Computation Code
//
@@ -6177,11 +6453,9 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
SmallVector<Instruction *, 16> Worklist;
PushLoopPHIs(L, Worklist);
- SmallPtrSet<Instruction *, 8> Visited;
+ SmallPtrSet<Instruction *, 8> Discovered;
while (!Worklist.empty()) {
Instruction *I = Worklist.pop_back_val();
- if (!Visited.insert(I).second)
- continue;
ValueExprMapType::iterator It =
ValueExprMap.find_as(static_cast<Value *>(I));
@@ -6202,7 +6476,31 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
ConstantEvolutionLoopExitValue.erase(PN);
}
- PushDefUseChildren(I, Worklist);
+ // Since we don't need to invalidate anything for correctness and we're
+ // only invalidating to make SCEV's results more precise, we get to stop
+ // early to avoid invalidating too much. This is especially important in
+ // cases like:
+ //
+ // %v = f(pn0, pn1) // pn0 and pn1 used through some other phi node
+ // loop0:
+ // %pn0 = phi
+ // ...
+ // loop1:
+ // %pn1 = phi
+ // ...
+ //
+ // where both loop0 and loop1's backedge taken count uses the SCEV
+ // expression for %v. If we don't have the early stop below then in cases
+ // like the above, getBackedgeTakenInfo(loop1) will clear out the trip
+ // count for loop0 and getBackedgeTakenInfo(loop0) will clear out the trip
+ // count for loop1, effectively nullifying SCEV's trip count cache.
+ for (auto *U : I->users())
+ if (auto *I = dyn_cast<Instruction>(U)) {
+ auto *LoopForUser = LI.getLoopFor(I->getParent());
+ if (LoopForUser && L->contains(LoopForUser) &&
+ Discovered.insert(I).second)
+ Worklist.push_back(I);
+ }
}
}
@@ -6217,7 +6515,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
void ScalarEvolution::forgetLoop(const Loop *L) {
// Drop any stored trip count value.
auto RemoveLoopFromBackedgeMap =
- [L](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
+ [](DenseMap<const Loop *, BackedgeTakenInfo> &Map, const Loop *L) {
auto BTCPos = Map.find(L);
if (BTCPos != Map.end()) {
BTCPos->second.clear();
@@ -6225,47 +6523,59 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
}
};
- RemoveLoopFromBackedgeMap(BackedgeTakenCounts);
- RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts);
+ SmallVector<const Loop *, 16> LoopWorklist(1, L);
+ SmallVector<Instruction *, 32> Worklist;
+ SmallPtrSet<Instruction *, 16> Visited;
- // Drop information about predicated SCEV rewrites for this loop.
- for (auto I = PredicatedSCEVRewrites.begin();
- I != PredicatedSCEVRewrites.end();) {
- std::pair<const SCEV *, const Loop *> Entry = I->first;
- if (Entry.second == L)
- PredicatedSCEVRewrites.erase(I++);
- else
- ++I;
- }
+ // Iterate over all the loops and sub-loops to drop SCEV information.
+ while (!LoopWorklist.empty()) {
+ auto *CurrL = LoopWorklist.pop_back_val();
- // Drop information about expressions based on loop-header PHIs.
- SmallVector<Instruction *, 16> Worklist;
- PushLoopPHIs(L, Worklist);
+ RemoveLoopFromBackedgeMap(BackedgeTakenCounts, CurrL);
+ RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts, CurrL);
- SmallPtrSet<Instruction *, 8> Visited;
- while (!Worklist.empty()) {
- Instruction *I = Worklist.pop_back_val();
- if (!Visited.insert(I).second)
- continue;
+ // Drop information about predicated SCEV rewrites for this loop.
+ for (auto I = PredicatedSCEVRewrites.begin();
+ I != PredicatedSCEVRewrites.end();) {
+ std::pair<const SCEV *, const Loop *> Entry = I->first;
+ if (Entry.second == CurrL)
+ PredicatedSCEVRewrites.erase(I++);
+ else
+ ++I;
+ }
- ValueExprMapType::iterator It =
- ValueExprMap.find_as(static_cast<Value *>(I));
- if (It != ValueExprMap.end()) {
- eraseValueFromMap(It->first);
- forgetMemoizedResults(It->second);
- if (PHINode *PN = dyn_cast<PHINode>(I))
- ConstantEvolutionLoopExitValue.erase(PN);
+ auto LoopUsersItr = LoopUsers.find(CurrL);
+ if (LoopUsersItr != LoopUsers.end()) {
+ for (auto *S : LoopUsersItr->second)
+ forgetMemoizedResults(S);
+ LoopUsers.erase(LoopUsersItr);
}
- PushDefUseChildren(I, Worklist);
- }
+ // Drop information about expressions based on loop-header PHIs.
+ PushLoopPHIs(CurrL, Worklist);
- // Forget all contained loops too, to avoid dangling entries in the
- // ValuesAtScopes map.
- for (Loop *I : *L)
- forgetLoop(I);
+ while (!Worklist.empty()) {
+ Instruction *I = Worklist.pop_back_val();
+ if (!Visited.insert(I).second)
+ continue;
- LoopPropertiesCache.erase(L);
+ ValueExprMapType::iterator It =
+ ValueExprMap.find_as(static_cast<Value *>(I));
+ if (It != ValueExprMap.end()) {
+ eraseValueFromMap(It->first);
+ forgetMemoizedResults(It->second);
+ if (PHINode *PN = dyn_cast<PHINode>(I))
+ ConstantEvolutionLoopExitValue.erase(PN);
+ }
+
+ PushDefUseChildren(I, Worklist);
+ }
+
+ LoopPropertiesCache.erase(CurrL);
+ // Forget all contained loops too, to avoid dangling entries in the
+ // ValuesAtScopes map.
+ LoopWorklist.append(CurrL->begin(), CurrL->end());
+ }
}
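
forgetLoop now visits contained loops with an explicit worklist rather than recursing, bounding stack usage on deeply nested loop trees. The generic shape of that conversion, sketched under the assumption of a node type exposing begin()/end() over its children:

template <typename NodeT, typename Fn>
void forEachNodeDepthFirst(NodeT *Root, Fn Process) {
  SmallVector<NodeT *, 16> Worklist;
  Worklist.push_back(Root);
  while (!Worklist.empty()) {
    NodeT *N = Worklist.pop_back_val();
    Process(N);                            // e.g. drop cached SCEV info for N
    Worklist.append(N->begin(), N->end()); // children replace the recursion
  }
}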
void ScalarEvolution::forgetValue(Value *V) {
@@ -6377,7 +6687,7 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
}
ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
- : ExactNotTaken(E), MaxNotTaken(E), MaxOrZero(false) {
+ : ExactNotTaken(E), MaxNotTaken(E) {
assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
isa<SCEVConstant>(MaxNotTaken)) &&
"No point in having a non-constant max backedge taken count!");
@@ -6422,7 +6732,8 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
&&ExitCounts,
bool Complete, const SCEV *MaxCount, bool MaxOrZero)
: MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) {
- typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo;
+ using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
+
ExitNotTaken.reserve(ExitCounts.size());
std::transform(
ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken),
@@ -6454,7 +6765,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
SmallVector<BasicBlock *, 8> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
- typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo;
+ using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
SmallVector<EdgeExitInfo, 4> ExitCounts;
bool CouldComputeBECount = true;
@@ -6521,8 +6832,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
- bool AllowPredicates) {
-
+ bool AllowPredicates) {
// Okay, we've chosen an exiting block. See what condition causes us to exit
// at this block and remember the exit block and whether all other targets
// lead to the loop header.
@@ -6785,19 +7095,19 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
BasicBlock *FBB,
bool ControlsExit,
bool AllowPredicates) {
-
// If the condition was exit on true, convert the condition to exit on false
- ICmpInst::Predicate Cond;
+ ICmpInst::Predicate Pred;
if (!L->contains(FBB))
- Cond = ExitCond->getPredicate();
+ Pred = ExitCond->getPredicate();
else
- Cond = ExitCond->getInversePredicate();
+ Pred = ExitCond->getInversePredicate();
+ const ICmpInst::Predicate OriginalPred = Pred;
// Handle common loops like: for (X = "string"; *X; ++X)
if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
ExitLimit ItCnt =
- computeLoadConstantCompareExitLimit(LI, RHS, L, Cond);
+ computeLoadConstantCompareExitLimit(LI, RHS, L, Pred);
if (ItCnt.hasAnyInfo())
return ItCnt;
}
@@ -6814,11 +7124,11 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
// If there is a loop-invariant, force it into the RHS.
std::swap(LHS, RHS);
- Cond = ICmpInst::getSwappedPredicate(Cond);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
}
// Simplify the operands before analyzing them.
- (void)SimplifyICmpOperands(Cond, LHS, RHS);
+ (void)SimplifyICmpOperands(Pred, LHS, RHS);
// If we have a comparison of a chrec against a constant, try to use value
// ranges to answer this query.
@@ -6827,13 +7137,13 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
if (AddRec->getLoop() == L) {
// Form the constant range.
ConstantRange CompRange =
- ConstantRange::makeExactICmpRegion(Cond, RHSC->getAPInt());
+ ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
}
- switch (Cond) {
+ switch (Pred) {
case ICmpInst::ICMP_NE: { // while (X != Y)
// Convert to: while (X-Y != 0)
ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit,
@@ -6849,7 +7159,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
}
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_ULT: { // while (X < Y)
- bool IsSigned = Cond == ICmpInst::ICMP_SLT;
+ bool IsSigned = Pred == ICmpInst::ICMP_SLT;
ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit,
AllowPredicates);
if (EL.hasAnyInfo()) return EL;
@@ -6857,7 +7167,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
}
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_UGT: { // while (X > Y)
- bool IsSigned = Cond == ICmpInst::ICMP_SGT;
+ bool IsSigned = Pred == ICmpInst::ICMP_SGT;
ExitLimit EL =
howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit,
AllowPredicates);
@@ -6875,7 +7185,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
return ExhaustiveCount;
return computeShiftCompareExitLimit(ExitCond->getOperand(0),
- ExitCond->getOperand(1), L, Cond);
+ ExitCond->getOperand(1), L, OriginalPred);
}
ScalarEvolution::ExitLimit
@@ -6920,7 +7230,6 @@ ScalarEvolution::computeLoadConstantCompareExitLimit(
Constant *RHS,
const Loop *L,
ICmpInst::Predicate predicate) {
-
if (LI->isVolatile()) return getCouldNotCompute();
// Check to see if the loaded pointer is a getelementptr of a global.
@@ -7333,8 +7642,8 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
Value *BEValue = PN->getIncomingValueForBlock(Latch);
// Execute the loop symbolically to determine the exit value.
- if (BEs.getActiveBits() >= 32)
- return RetVal = nullptr; // More than 2^32-1 iterations?? Not doing it!
+ assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
+ "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
unsigned NumIterations = BEs.getZExtValue(); // must be in range
unsigned IterationNum = 0;
@@ -7839,7 +8148,6 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
/// Find the roots of the quadratic equation for the given quadratic chrec
/// {L,+,M,+,N}. This returns either the two roots (which might be the same) or
/// two SCEVCouldNotCompute objects.
-///
static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>>
SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
@@ -8080,7 +8388,6 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
/// expressions are equal, however for the purposes of looking for a condition
/// guarding a loop, it can be useful to be a little more general, since a
/// front-end may have replicated the controlling expression.
-///
static bool HasSameValue(const SCEV *A, const SCEV *B) {
// Quick check to see if they are the same SCEV.
if (A == B) return true;
@@ -8527,7 +8834,6 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges(
bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
const SCEV *LHS,
const SCEV *RHS) {
-
// Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
// Return Y via OutY.
auto MatchBinaryAddToConst =
@@ -8693,7 +8999,6 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
DTN != HeaderDTN; DTN = DTN->getIDom()) {
-
assert(DTN && "should reach the loop header before reaching the root!");
BasicBlock *BB = DTN->getBlock();
@@ -9116,7 +9421,6 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
getNotSCEV(FoundLHS));
}
-
/// If Expr computes ~A, return A else return nullptr
static const SCEV *MatchNotExpr(const SCEV *Expr) {
const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
@@ -9132,7 +9436,6 @@ static const SCEV *MatchNotExpr(const SCEV *Expr) {
return AddRHS->getOperand(1);
}
-
/// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values?
template<typename MaxExprType>
static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
@@ -9143,7 +9446,6 @@ static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end();
}
-
/// Is MaybeMinExpr an SMin or UMin of Candidate and some other values?
template<typename MaxExprType>
static bool IsMinConsistingOf(ScalarEvolution &SE,
@@ -9159,7 +9461,6 @@ static bool IsMinConsistingOf(ScalarEvolution &SE,
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
-
// If both sides are affine addrecs for the same loop, with equal
// steps, and we know the recurrences don't wrap, then we only
// need to check the predicate on the starting values.
@@ -9295,7 +9596,9 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
} else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
Value *LL, *LR;
// FIXME: Once we have SDiv implemented, we can get rid of this matching.
+
using namespace llvm::PatternMatch;
+
if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
// Rules for division.
// We are going to perform some comparisons with Denominator and its
@@ -9510,14 +9813,54 @@ const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
return getUDivExpr(Delta, Step);
}
+const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
+ const SCEV *Stride,
+ const SCEV *End,
+ unsigned BitWidth,
+ bool IsSigned) {
+
+ assert(!isKnownNonPositive(Stride) &&
+ "Stride is expected strictly positive!");
+ // Calculate the maximum backedge count based on the range of values
+ // permitted by Start, End, and Stride.
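+ // Worked example (illustrative values only): for unsigned i8 with
+ // MinStart = 2, a stride range-minimum of 3, and an End range-maximum of
+ // 100, Limit = 255 - 2 = 253, MaxEnd = umin(100, 253) = 100, and the
+ // maximum backedge count is ceil((100 - 2) / 3) = 33.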
+ const SCEV *MaxBECount;
+ APInt MinStart =
+ IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
+
+ APInt StrideForMaxBECount =
+ IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
+
+ // We already know that the stride is positive, so we paper over conservatism
+ // in our range computation by forcing StrideForMaxBECount to be at least one.
+ // In theory this is unnecessary, but we expect MaxBECount to be a
+ // SCEVConstant, and (udiv <constant> 0) is not constant folded by SCEV (there
+ // is nothing to constant fold it to).
+ APInt One(BitWidth, 1, IsSigned);
+ StrideForMaxBECount = APIntOps::smax(One, StrideForMaxBECount);
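+ // (A conservative range for a positive Stride, e.g. [0, 8), would otherwise
+ // produce a range-minimum of 0 here.)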
+
+ APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
+ : APInt::getMaxValue(BitWidth);
+ APInt Limit = MaxValue - (StrideForMaxBECount - 1);
+
+ // Although End can be a MAX expression we estimate MaxEnd considering only
+ // the case End = RHS of the loop termination condition. This is safe because
+ // in the other case (End - Start) is zero, leading to a zero maximum backedge
+ // taken count.
+ APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
+ : APIntOps::umin(getUnsignedRangeMax(End), Limit);
+
+ MaxBECount = computeBECount(getConstant(MaxEnd - MinStart) /* Delta */,
+ getConstant(StrideForMaxBECount) /* Step */,
+ false /* Equality */);
+
+ return MaxBECount;
+}
+
ScalarEvolution::ExitLimit
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned,
bool ControlsExit, bool AllowPredicates) {
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
- // We handle only IV < Invariant
- if (!isLoopInvariant(RHS, L))
- return getCouldNotCompute();
const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
bool PredicatedIV = false;
@@ -9588,7 +9931,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
if (PredicatedIV || !NoWrap || isKnownNonPositive(Stride) ||
!loopHasNoSideEffects(L))
return getCouldNotCompute();
-
} else if (!Stride->isOne() &&
doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
// Avoid proven overflow cases: this will ensure that the backedge taken
@@ -9601,6 +9943,17 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
: ICmpInst::ICMP_ULT;
const SCEV *Start = IV->getStart();
const SCEV *End = RHS;
+ // When the RHS is not invariant, we do not know the end bound of the loop and
+ // cannot calculate the ExactBECount needed by ExitLimit. However, we can
+ // calculate the MaxBECount, given the start, stride and max value for the end
+ // bound of the loop (RHS), and the fact that IV does not overflow (which is
+ // checked above).
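+ // As an illustration only: in "while (i < *p) ++i" with *p modified inside
+ // the loop, the bound is not invariant, yet the value ranges of i and *p
+ // still bound the trip count from above.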
+ if (!isLoopInvariant(RHS, L)) {
+ const SCEV *MaxBECount = computeMaxBECountForLT(
+ Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
+ return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
+ false /*MaxOrZero*/, Predicates);
+ }
// If the backedge is taken at least once, then it will be taken
// (End-Start)/Stride times (rounded up to a multiple of Stride), where Start
// is the LHS value of the less-than comparison the first time it is evaluated
@@ -9633,37 +9986,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
MaxBECount = BECountIfBackedgeTaken;
MaxOrZero = true;
} else {
- // Calculate the maximum backedge count based on the range of values
- // permitted by Start, End, and Stride.
- APInt MinStart = IsSigned ? getSignedRangeMin(Start)
- : getUnsignedRangeMin(Start);
-
- unsigned BitWidth = getTypeSizeInBits(LHS->getType());
-
- APInt StrideForMaxBECount;
-
- if (PositiveStride)
- StrideForMaxBECount =
- IsSigned ? getSignedRangeMin(Stride)
- : getUnsignedRangeMin(Stride);
- else
- // Using a stride of 1 is safe when computing max backedge taken count for
- // a loop with unknown stride.
- StrideForMaxBECount = APInt(BitWidth, 1, IsSigned);
-
- APInt Limit =
- IsSigned ? APInt::getSignedMaxValue(BitWidth) - (StrideForMaxBECount - 1)
- : APInt::getMaxValue(BitWidth) - (StrideForMaxBECount - 1);
-
- // Although End can be a MAX expression we estimate MaxEnd considering only
- // the case End = RHS. This is safe because in the other case (End - Start)
- // is zero, leading to a zero maximum backedge taken count.
- APInt MaxEnd =
- IsSigned ? APIntOps::smin(getSignedRangeMax(RHS), Limit)
- : APIntOps::umin(getUnsignedRangeMax(RHS), Limit);
-
- MaxBECount = computeBECount(getConstant(MaxEnd - MinStart),
- getConstant(StrideForMaxBECount), false);
+ MaxBECount = computeMaxBECountForLT(
+ Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
}
if (isa<SCEVCouldNotCompute>(MaxBECount) &&
@@ -9874,6 +10198,7 @@ static inline bool containsUndefs(const SCEV *S) {
}
namespace {
+
// Collect all steps of SCEV expressions.
struct SCEVCollectStrides {
ScalarEvolution &SE;
@@ -9887,6 +10212,7 @@ struct SCEVCollectStrides {
Strides.push_back(AR->getStepRecurrence(SE));
return true;
}
+
bool isDone() const { return false; }
};
@@ -9894,8 +10220,7 @@ struct SCEVCollectStrides {
struct SCEVCollectTerms {
SmallVectorImpl<const SCEV *> &Terms;
- SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T)
- : Terms(T) {}
+ SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T) : Terms(T) {}
bool follow(const SCEV *S) {
if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) ||
@@ -9910,6 +10235,7 @@ struct SCEVCollectTerms {
// Keep looking.
return true;
}
+
bool isDone() const { return false; }
};
@@ -9918,7 +10244,7 @@ struct SCEVHasAddRec {
bool &ContainsAddRec;
SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) {
- ContainsAddRec = false;
+ ContainsAddRec = false;
}
bool follow(const SCEV *S) {
@@ -9932,6 +10258,7 @@ struct SCEVHasAddRec {
// Keep looking.
return true;
}
+
bool isDone() const { return false; }
};
@@ -9985,9 +10312,11 @@ struct SCEVCollectAddRecMultiplies {
// Keep looking.
return true;
}
+
bool isDone() const { return false; }
};
-}
+
+} // end anonymous namespace
/// Find parametric terms in this SCEVAddRecExpr. We first look for parameters in
/// two places:
@@ -10066,7 +10395,6 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,
return true;
}
-
// Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
for (const SCEV *T : Terms)
@@ -10181,7 +10509,6 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
void ScalarEvolution::computeAccessFunctions(
const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts,
SmallVectorImpl<const SCEV *> &Sizes) {
-
// Early exit in case this SCEV is not an affine multivariate function.
if (Sizes.empty())
return;
@@ -10285,7 +10612,6 @@ void ScalarEvolution::computeAccessFunctions(
/// DelinearizationPass that walks through all loads and stores of a function
/// asking for the SCEV of the memory access with respect to all enclosing
/// loops, calling SCEV->delinearize on that and printing the results.
-
void ScalarEvolution::delinearize(const SCEV *Expr,
SmallVectorImpl<const SCEV *> &Subscripts,
SmallVectorImpl<const SCEV *> &Sizes,
@@ -10374,11 +10700,8 @@ ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
AssumptionCache &AC, DominatorTree &DT,
LoopInfo &LI)
: F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
- CouldNotCompute(new SCEVCouldNotCompute()),
- WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
- ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64),
- FirstUnknown(nullptr) {
-
+ CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
+ LoopDispositions(64), BlockDispositions(64) {
// To use guards for proving predicates, we need to scan every instruction in
// relevant basic blocks, and not just terminators. Doing this is a waste of
// time if the IR does not actually contain any calls to
@@ -10399,7 +10722,6 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
ValueExprMap(std::move(Arg.ValueExprMap)),
PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
- WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
PredicatedBackedgeTakenCounts(
@@ -10415,6 +10737,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
UniquePreds(std::move(Arg.UniquePreds)),
SCEVAllocator(std::move(Arg.SCEVAllocator)),
+ LoopUsers(std::move(Arg.LoopUsers)),
PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
FirstUnknown(Arg.FirstUnknown) {
Arg.FirstUnknown = nullptr;
@@ -10647,9 +10970,11 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
if (!L)
return LoopVariant;
- // This recurrence is variant w.r.t. L if L contains AR's loop.
- if (L->contains(AR->getLoop()))
+ // Everything that is not defined at loop entry is variant.
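+ // (If L's header dominates AR's loop's header, AR is only computed after
+ // entering L, so its value is not fixed at L's entry.)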
+ if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
return LoopVariant;
+ assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
+ " dominate the contained loop's header?");
// This recurrence is invariant w.r.t. L if AR's loop contains L.
if (AR->getLoop()->contains(L))
@@ -10806,7 +11131,16 @@ bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
}
-void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
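+// Return true if S appears as an operand of this exit limit, i.e. inside
+// ExactNotTaken or MaxNotTaken (a SCEVCouldNotCompute count never matches).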
+bool ScalarEvolution::ExitLimit::hasOperand(const SCEV *S) const {
+ auto IsS = [&](const SCEV *X) { return S == X; };
+ auto ContainsS = [&](const SCEV *X) {
+ return !isa<SCEVCouldNotCompute>(X) && SCEVExprContains(X, IsS);
+ };
+ return ContainsS(ExactNotTaken) || ContainsS(MaxNotTaken);
+}
+
+void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
ValuesAtScopes.erase(S);
LoopDispositions.erase(S);
BlockDispositions.erase(S);
@@ -10816,7 +11150,7 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
HasRecMap.erase(S);
MinTrailingZerosCache.erase(S);
- for (auto I = PredicatedSCEVRewrites.begin();
+ for (auto I = PredicatedSCEVRewrites.begin();
I != PredicatedSCEVRewrites.end();) {
std::pair<const SCEV *, const Loop *> Entry = I->first;
if (Entry.first == S)
@@ -10841,6 +11175,25 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
}
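+// Register S in the use list of every loop referenced by an add-rec inside S,
+// so that invalidating one of those loops can also invalidate S.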
+void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
+ struct FindUsedLoops {
+ SmallPtrSet<const Loop *, 8> LoopsUsed;
+ bool follow(const SCEV *S) {
+ if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
+ LoopsUsed.insert(AR->getLoop());
+ return true;
+ }
+
+ bool isDone() const { return false; }
+ };
+
+ FindUsedLoops F;
+ SCEVTraversal<FindUsedLoops>(F).visitAll(S);
+
+ for (auto *L : F.LoopsUsed)
+ LoopUsers[L].push_back(S);
+}
+
void ScalarEvolution::verify() const {
ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
ScalarEvolution SE2(F, TLI, AC, DT, LI);
@@ -10849,9 +11202,12 @@ void ScalarEvolution::verify() const {
// Maps SCEV expressions from one ScalarEvolution "universe" to another.
struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
+ SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
+
const SCEV *visitConstant(const SCEVConstant *Constant) {
return SE.getConstant(Constant->getAPInt());
}
+
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
return SE.getUnknown(Expr->getValue());
}
@@ -10859,7 +11215,6 @@ void ScalarEvolution::verify() const {
const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
return SE.getCouldNotCompute();
}
- SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
};
SCEVMapper SCM(SE2);
@@ -10948,6 +11303,7 @@ INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
"Scalar Evolution Analysis", false, true)
+
char ScalarEvolutionWrapperPass::ID = 0;
ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
@@ -11023,6 +11379,7 @@ namespace {
class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
public:
+
/// Rewrites \p S in the context of a loop L and the SCEV predication
/// infrastructure.
///
@@ -11038,11 +11395,6 @@ public:
return Rewriter.visit(S);
}
- SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
- SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
- SCEVUnionPredicate *Pred)
- : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
-
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
if (Pred) {
auto ExprPreds = Pred->getPredicatesForExpr(Expr);
@@ -11087,6 +11439,11 @@ public:
}
private:
+ explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
+ SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
+ SCEVUnionPredicate *Pred)
+ : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
+
bool addOverflowAssumption(const SCEVPredicate *P) {
if (!NewPreds) {
// Check if we've already made this assumption.
@@ -11103,10 +11460,10 @@ private:
}
// If \p Expr represents a PHINode, we try to see if it can be represented
- // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
+ // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
// to add this predicate as a runtime overflow check, we return the AddRec.
- // If \p Expr does not meet these conditions (is not a PHI node, or we
- // couldn't create an AddRec for it, or couldn't add the predicate), we just
+ // If \p Expr does not meet these conditions (is not a PHI node, or we
+ // couldn't create an AddRec for it, or couldn't add the predicate), we just
// return \p Expr.
const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
if (!isa<PHINode>(Expr->getValue()))
@@ -11121,11 +11478,12 @@ private:
}
return PredicatedRewrite->first;
}
-
+
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
SCEVUnionPredicate *Pred;
const Loop *L;
};
+
} // end anonymous namespace
const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
@@ -11136,7 +11494,6 @@ const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
const SCEV *S, const Loop *L,
SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
-
SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
@@ -11292,7 +11649,7 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) {
PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
Loop &L)
- : SE(SE), L(L), Generation(0), BackedgeCount(nullptr) {}
+ : SE(SE), L(L) {}
const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
const SCEV *Expr = SE.getSCEV(V);