diff options
Diffstat (limited to 'lib/Transforms/IPO/BlockExtractor.cpp')
-rw-r--r-- | lib/Transforms/IPO/BlockExtractor.cpp | 122 |
1 files changed, 88 insertions, 34 deletions
diff --git a/lib/Transforms/IPO/BlockExtractor.cpp b/lib/Transforms/IPO/BlockExtractor.cpp index ff5ee817da49..6c365f3f3cbe 100644 --- a/lib/Transforms/IPO/BlockExtractor.cpp +++ b/lib/Transforms/IPO/BlockExtractor.cpp @@ -1,9 +1,8 @@ //===- BlockExtractor.cpp - Extracts blocks into their own functions ------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -23,6 +22,7 @@ #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CodeExtractor.h" + using namespace llvm; #define DEBUG_TYPE "block-extractor" @@ -36,22 +36,48 @@ static cl::opt<std::string> BlockExtractorFile( cl::opt<bool> BlockExtractorEraseFuncs("extract-blocks-erase-funcs", cl::desc("Erase the existing functions"), cl::Hidden); - namespace { class BlockExtractor : public ModulePass { - SmallVector<BasicBlock *, 16> Blocks; + SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupsOfBlocks; bool EraseFunctions; - SmallVector<std::pair<std::string, std::string>, 32> BlocksByName; + /// Map a function name to groups of blocks. + SmallVector<std::pair<std::string, SmallVector<std::string, 4>>, 4> + BlocksByName; + + void init(const SmallVectorImpl<SmallVector<BasicBlock *, 16>> + &GroupsOfBlocksToExtract) { + for (const SmallVectorImpl<BasicBlock *> &GroupOfBlocks : + GroupsOfBlocksToExtract) { + SmallVector<BasicBlock *, 16> NewGroup; + NewGroup.append(GroupOfBlocks.begin(), GroupOfBlocks.end()); + GroupsOfBlocks.emplace_back(NewGroup); + } + if (!BlockExtractorFile.empty()) + loadFile(); + } public: static char ID; BlockExtractor(const SmallVectorImpl<BasicBlock *> &BlocksToExtract, bool EraseFunctions) - : ModulePass(ID), Blocks(BlocksToExtract.begin(), BlocksToExtract.end()), - EraseFunctions(EraseFunctions) { - if (!BlockExtractorFile.empty()) - loadFile(); + : ModulePass(ID), EraseFunctions(EraseFunctions) { + // We want one group per element of the input list. + SmallVector<SmallVector<BasicBlock *, 16>, 4> MassagedGroupsOfBlocks; + for (BasicBlock *BB : BlocksToExtract) { + SmallVector<BasicBlock *, 16> NewGroup; + NewGroup.push_back(BB); + MassagedGroupsOfBlocks.push_back(NewGroup); + } + init(MassagedGroupsOfBlocks); } + + BlockExtractor(const SmallVectorImpl<SmallVector<BasicBlock *, 16>> + &GroupsOfBlocksToExtract, + bool EraseFunctions) + : ModulePass(ID), EraseFunctions(EraseFunctions) { + init(GroupsOfBlocksToExtract); + } + BlockExtractor() : BlockExtractor(SmallVector<BasicBlock *, 0>(), false) {} bool runOnModule(Module &M) override; @@ -70,6 +96,12 @@ ModulePass *llvm::createBlockExtractorPass( const SmallVectorImpl<BasicBlock *> &BlocksToExtract, bool EraseFunctions) { return new BlockExtractor(BlocksToExtract, EraseFunctions); } +ModulePass *llvm::createBlockExtractorPass( + const SmallVectorImpl<SmallVector<BasicBlock *, 16>> + &GroupsOfBlocksToExtract, + bool EraseFunctions) { + return new BlockExtractor(GroupsOfBlocksToExtract, EraseFunctions); +} /// Gets all of the blocks specified in the input file. void BlockExtractor::loadFile() { @@ -82,8 +114,17 @@ void BlockExtractor::loadFile() { Buf->getBuffer().split(Lines, '\n', /*MaxSplit=*/-1, /*KeepEmpty=*/false); for (const auto &Line : Lines) { - auto FBPair = Line.split(' '); - BlocksByName.push_back({FBPair.first, FBPair.second}); + SmallVector<StringRef, 4> LineSplit; + Line.split(LineSplit, ' ', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + if (LineSplit.empty()) + continue; + SmallVector<StringRef, 4> BBNames; + LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + if (BBNames.empty()) + report_fatal_error("Missing bbs name"); + BlocksByName.push_back({LineSplit[0], {BBNames.begin(), BBNames.end()}}); } } @@ -130,33 +171,46 @@ bool BlockExtractor::runOnModule(Module &M) { } // Get all the blocks specified in the input file. + unsigned NextGroupIdx = GroupsOfBlocks.size(); + GroupsOfBlocks.resize(NextGroupIdx + BlocksByName.size()); for (const auto &BInfo : BlocksByName) { Function *F = M.getFunction(BInfo.first); if (!F) report_fatal_error("Invalid function name specified in the input file"); - auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) { - return BB.getName().equals(BInfo.second); - }); - if (Res == F->end()) - report_fatal_error("Invalid block name specified in the input file"); - Blocks.push_back(&*Res); + for (const auto &BBInfo : BInfo.second) { + auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) { + return BB.getName().equals(BBInfo); + }); + if (Res == F->end()) + report_fatal_error("Invalid block name specified in the input file"); + GroupsOfBlocks[NextGroupIdx].push_back(&*Res); + } + ++NextGroupIdx; } - // Extract basic blocks. - for (BasicBlock *BB : Blocks) { - // Check if the module contains BB. - if (BB->getParent()->getParent() != &M) - report_fatal_error("Invalid basic block"); - LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting " - << BB->getParent()->getName() << ":" << BB->getName() - << "\n"); - SmallVector<BasicBlock *, 2> BlocksToExtractVec; - BlocksToExtractVec.push_back(BB); - if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) - BlocksToExtractVec.push_back(II->getUnwindDest()); - CodeExtractor(BlocksToExtractVec).extractCodeRegion(); - ++NumExtracted; - Changed = true; + // Extract each group of basic blocks. + for (auto &BBs : GroupsOfBlocks) { + SmallVector<BasicBlock *, 32> BlocksToExtractVec; + for (BasicBlock *BB : BBs) { + // Check if the module contains BB. + if (BB->getParent()->getParent() != &M) + report_fatal_error("Invalid basic block"); + LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting " + << BB->getParent()->getName() << ":" << BB->getName() + << "\n"); + BlocksToExtractVec.push_back(BB); + if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) + BlocksToExtractVec.push_back(II->getUnwindDest()); + ++NumExtracted; + Changed = true; + } + Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(); + if (F) + LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName() + << "' in: " << F->getName() << '\n'); + else + LLVM_DEBUG(dbgs() << "Failed to extract for group '" + << (*BBs.begin())->getName() << "'\n"); } // Erase the functions. |