aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopFlatten.cpp')
-rw-r--r--llvm/lib/Transforms/Scalar/LoopFlatten.cpp97
1 files changed, 15 insertions, 82 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index 7d9ce8d35e0b..edc8a4956dd1 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -65,11 +65,8 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/LoopPassManager.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -318,12 +315,12 @@ static bool verifyTripCount(Value *RHS, Loop *L,
return false;
}
- // The Extend=false flag is used for getTripCountFromExitCount as we want
- // to verify and match it with the pattern matched tripcount. Please note
- // that overflow checks are performed in checkOverflow, but are first tried
- // to avoid by widening the IV.
+ // Evaluating in the trip count's type can not overflow here as the overflow
+ // checks are performed in checkOverflow, but are first tried to avoid by
+ // widening the IV.
const SCEV *SCEVTripCount =
- SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false);
+ SE->getTripCountFromExitCount(BackedgeTakenCount,
+ BackedgeTakenCount->getType(), L);
const SCEV *SCEVRHS = SE->getSCEV(RHS);
if (SCEVRHS == SCEVTripCount)
@@ -336,7 +333,8 @@ static bool verifyTripCount(Value *RHS, Loop *L,
// Find the extended backedge taken count and extended trip count using
// SCEV. One of these should now match the RHS of the compare.
BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
- SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
+ SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt,
+ RHS->getType(), L);
if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
@@ -918,20 +916,6 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
}
-bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
- AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U,
- MemorySSAUpdater *MSSAU) {
- bool Changed = false;
- for (Loop *InnerLoop : LN.getLoops()) {
- auto *OuterLoop = InnerLoop->getParentLoop();
- if (!OuterLoop)
- continue;
- FlattenInfo FI(OuterLoop, InnerLoop);
- Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
- }
- return Changed;
-}
-
PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
LoopStandardAnalysisResults &AR,
LPMUpdater &U) {
@@ -949,8 +933,14 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
// in simplified form, and also needs LCSSA. Running
// this pass will simplify all loops that contain inner loops,
// regardless of whether anything ends up being flattened.
- Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,
- MSSAU ? &*MSSAU : nullptr);
+ for (Loop *InnerLoop : LN.getLoops()) {
+ auto *OuterLoop = InnerLoop->getParentLoop();
+ if (!OuterLoop)
+ continue;
+ FlattenInfo FI(OuterLoop, InnerLoop);
+ Changed |= FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,
+ MSSAU ? &*MSSAU : nullptr);
+ }
if (!Changed)
return PreservedAnalyses::all();
@@ -963,60 +953,3 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
PA.preserve<MemorySSAAnalysis>();
return PA;
}
-
-namespace {
-class LoopFlattenLegacyPass : public FunctionPass {
-public:
- static char ID; // Pass ID, replacement for typeid
- LoopFlattenLegacyPass() : FunctionPass(ID) {
- initializeLoopFlattenLegacyPassPass(*PassRegistry::getPassRegistry());
- }
-
- // Possibly flatten loop L into its child.
- bool runOnFunction(Function &F) override;
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- getLoopAnalysisUsage(AU);
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addPreserved<TargetTransformInfoWrapperPass>();
- AU.addRequired<AssumptionCacheTracker>();
- AU.addPreserved<AssumptionCacheTracker>();
- AU.addPreserved<MemorySSAWrapperPass>();
- }
-};
-} // namespace
-
-char LoopFlattenLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
- false, false)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
-INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
- false, false)
-
-FunctionPass *llvm::createLoopFlattenPass() {
- return new LoopFlattenLegacyPass();
-}
-
-bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
- ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
- auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
- DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr;
- auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
- auto *TTI = &TTIP.getTTI(F);
- auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
- auto *MSSA = getAnalysisIfAvailable<MemorySSAWrapperPass>();
-
- std::optional<MemorySSAUpdater> MSSAU;
- if (MSSA)
- MSSAU = MemorySSAUpdater(&MSSA->getMSSA());
-
- bool Changed = false;
- for (Loop *L : *LI) {
- auto LN = LoopNest::getLoopNest(*L, *SE);
- Changed |=
- Flatten(*LN, DT, LI, SE, AC, TTI, nullptr, MSSAU ? &*MSSAU : nullptr);
- }
- return Changed;
-}