diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopFlatten.cpp')
-rw-r--r-- | llvm/lib/Transforms/Scalar/LoopFlatten.cpp | 97 |
1 files changed, 15 insertions, 82 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index 7d9ce8d35e0b..edc8a4956dd1 100644 --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -65,11 +65,8 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -318,12 +315,12 @@ static bool verifyTripCount(Value *RHS, Loop *L, return false; } - // The Extend=false flag is used for getTripCountFromExitCount as we want - // to verify and match it with the pattern matched tripcount. Please note - // that overflow checks are performed in checkOverflow, but are first tried - // to avoid by widening the IV. + // Evaluating in the trip count's type can not overflow here as the overflow + // checks are performed in checkOverflow, but are first tried to avoid by + // widening the IV. const SCEV *SCEVTripCount = - SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false); + SE->getTripCountFromExitCount(BackedgeTakenCount, + BackedgeTakenCount->getType(), L); const SCEV *SCEVRHS = SE->getSCEV(RHS); if (SCEVRHS == SCEVTripCount) @@ -336,7 +333,8 @@ static bool verifyTripCount(Value *RHS, Loop *L, // Find the extended backedge taken count and extended trip count using // SCEV. One of these should now match the RHS of the compare. BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType()); - SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false); + SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, + RHS->getType(), L); if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) { LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); return false; @@ -918,20 +916,6 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); } -bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, - AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U, - MemorySSAUpdater *MSSAU) { - bool Changed = false; - for (Loop *InnerLoop : LN.getLoops()) { - auto *OuterLoop = InnerLoop->getParentLoop(); - if (!OuterLoop) - continue; - FlattenInfo FI(OuterLoop, InnerLoop); - Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); - } - return Changed; -} - PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, LoopStandardAnalysisResults &AR, LPMUpdater &U) { @@ -949,8 +933,14 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, // in simplified form, and also needs LCSSA. Running // this pass will simplify all loops that contain inner loops, // regardless of whether anything ends up being flattened. - Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U, - MSSAU ? &*MSSAU : nullptr); + for (Loop *InnerLoop : LN.getLoops()) { + auto *OuterLoop = InnerLoop->getParentLoop(); + if (!OuterLoop) + continue; + FlattenInfo FI(OuterLoop, InnerLoop); + Changed |= FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U, + MSSAU ? &*MSSAU : nullptr); + } if (!Changed) return PreservedAnalyses::all(); @@ -963,60 +953,3 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, PA.preserve<MemorySSAAnalysis>(); return PA; } - -namespace { -class LoopFlattenLegacyPass : public FunctionPass { -public: - static char ID; // Pass ID, replacement for typeid - LoopFlattenLegacyPass() : FunctionPass(ID) { - initializeLoopFlattenLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - // Possibly flatten loop L into its child. - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - getLoopAnalysisUsage(AU); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addPreserved<TargetTransformInfoWrapperPass>(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addPreserved<AssumptionCacheTracker>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } -}; -} // namespace - -char LoopFlattenLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops", - false, false) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops", - false, false) - -FunctionPass *llvm::createLoopFlattenPass() { - return new LoopFlattenLegacyPass(); -} - -bool LoopFlattenLegacyPass::runOnFunction(Function &F) { - ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr; - auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>(); - auto *TTI = &TTIP.getTTI(F); - auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto *MSSA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - - std::optional<MemorySSAUpdater> MSSAU; - if (MSSA) - MSSAU = MemorySSAUpdater(&MSSA->getMSSA()); - - bool Changed = false; - for (Loop *L : *LI) { - auto LN = LoopNest::getLoopNest(*L, *SE); - Changed |= - Flatten(*LN, DT, LI, SE, AC, TTI, nullptr, MSSAU ? &*MSSAU : nullptr); - } - return Changed; -} |