diff options
Diffstat (limited to 'llvm/lib/Transforms/IPO/LoopExtractor.cpp')
-rw-r--r-- | llvm/lib/Transforms/IPO/LoopExtractor.cpp | 167 |
1 files changed, 167 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/IPO/LoopExtractor.cpp b/llvm/lib/Transforms/IPO/LoopExtractor.cpp new file mode 100644 index 000000000000..add2ae053735 --- /dev/null +++ b/llvm/lib/Transforms/IPO/LoopExtractor.cpp @@ -0,0 +1,167 @@ +//===- LoopExtractor.cpp - Extract each loop into a new function ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A pass wrapper around the ExtractLoop() scalar transformation to extract each +// top-level loop into its own new function. If the loop is the ONLY loop in a +// given function, it is not touched. This is a pass most useful for debugging +// via bugpoint. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/CodeExtractor.h" +#include <fstream> +#include <set> +using namespace llvm; + +#define DEBUG_TYPE "loop-extract" + +STATISTIC(NumExtracted, "Number of loops extracted"); + +namespace { + struct LoopExtractor : public LoopPass { + static char ID; // Pass identification, replacement for typeid + unsigned NumLoops; + + explicit LoopExtractor(unsigned numLoops = ~0) + : LoopPass(ID), NumLoops(numLoops) { + initializeLoopExtractorPass(*PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequiredID(BreakCriticalEdgesID); + AU.addRequiredID(LoopSimplifyID); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addUsedIfAvailable<AssumptionCacheTracker>(); + } + }; +} + +char LoopExtractor::ID = 0; +INITIALIZE_PASS_BEGIN(LoopExtractor, "loop-extract", + "Extract loops into new functions", false, false) +INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(LoopExtractor, "loop-extract", + "Extract loops into new functions", false, false) + +namespace { + /// SingleLoopExtractor - For bugpoint. + struct SingleLoopExtractor : public LoopExtractor { + static char ID; // Pass identification, replacement for typeid + SingleLoopExtractor() : LoopExtractor(1) {} + }; +} // End anonymous namespace + +char SingleLoopExtractor::ID = 0; +INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single", + "Extract at most one loop into a new function", false, false) + +// createLoopExtractorPass - This pass extracts all natural loops from the +// program into a function if it can. +// +Pass *llvm::createLoopExtractorPass() { return new LoopExtractor(); } + +bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &LPM) { + if (skipLoop(L)) + return false; + + // Only visit top-level loops. + if (L->getParentLoop()) + return false; + + // If LoopSimplify form is not available, stay out of trouble. + if (!L->isLoopSimplifyForm()) + return false; + + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + bool Changed = false; + + // If there is more than one top-level loop in this function, extract all of + // the loops. Otherwise there is exactly one top-level loop; in this case if + // this function is more than a minimal wrapper around the loop, extract + // the loop. + bool ShouldExtractLoop = false; + + // Extract the loop if the entry block doesn't branch to the loop header. + Instruction *EntryTI = + L->getHeader()->getParent()->getEntryBlock().getTerminator(); + if (!isa<BranchInst>(EntryTI) || + !cast<BranchInst>(EntryTI)->isUnconditional() || + EntryTI->getSuccessor(0) != L->getHeader()) { + ShouldExtractLoop = true; + } else { + // Check to see if any exits from the loop are more than just return + // blocks. + SmallVector<BasicBlock*, 8> ExitBlocks; + L->getExitBlocks(ExitBlocks); + for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) + if (!isa<ReturnInst>(ExitBlocks[i]->getTerminator())) { + ShouldExtractLoop = true; + break; + } + } + + if (ShouldExtractLoop) { + // We must omit EH pads. EH pads must accompany the invoke + // instruction. But this would result in a loop in the extracted + // function. An infinite cycle occurs when it tries to extract that loop as + // well. + SmallVector<BasicBlock*, 8> ExitBlocks; + L->getExitBlocks(ExitBlocks); + for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) + if (ExitBlocks[i]->isEHPad()) { + ShouldExtractLoop = false; + break; + } + } + + if (ShouldExtractLoop) { + if (NumLoops == 0) return Changed; + --NumLoops; + AssumptionCache *AC = nullptr; + Function &Func = *L->getHeader()->getParent(); + if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>()) + AC = ACT->lookupAssumptionCache(Func); + CodeExtractorAnalysisCache CEAC(Func); + CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC); + if (Extractor.extractCodeRegion(CEAC) != nullptr) { + Changed = true; + // After extraction, the loop is replaced by a function call, so + // we shouldn't try to run any more loop passes on it. + LPM.markLoopAsDeleted(*L); + LI.erase(L); + } + ++NumExtracted; + } + + return Changed; +} + +// createSingleLoopExtractorPass - This pass extracts one natural loop from the +// program into a function if it can. This is used by bugpoint. +// +Pass *llvm::createSingleLoopExtractorPass() { + return new SingleLoopExtractor(); +} |