summaryrefslogtreecommitdiff
path: root/lib/Transforms/Utils/CodeExtractor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/Utils/CodeExtractor.cpp')
-rw-r--r--lib/Transforms/Utils/CodeExtractor.cpp227
1 files changed, 183 insertions, 44 deletions
diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp
index 7a404241cb14..f31dab9f96af 100644
--- a/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/lib/Transforms/Utils/CodeExtractor.cpp
@@ -66,6 +66,7 @@
#include <vector>
using namespace llvm;
+using ProfileCount = Function::ProfileCount;
#define DEBUG_TYPE "code-extractor"
@@ -77,12 +78,10 @@ static cl::opt<bool>
AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
cl::desc("Aggregate arguments to code-extracted functions"));
-/// \brief Test whether a block is valid for extraction.
-bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB,
- bool AllowVarArgs) {
- // Landing pads must be in the function where they were inserted for cleanup.
- if (BB.isEHPad())
- return false;
+/// Test whether a block is valid for extraction.
+static bool isBlockValidForExtraction(const BasicBlock &BB,
+ const SetVector<BasicBlock *> &Result,
+ bool AllowVarArgs, bool AllowAlloca) {
// taking the address of a basic block moved to another function is illegal
if (BB.hasAddressTaken())
return false;
@@ -111,11 +110,63 @@ bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB,
}
}
- // Don't hoist code containing allocas or invokes. If explicitly requested,
- // allow vastart.
+ // If explicitly requested, allow vastart and alloca. For invoke instructions
+ // verify that extraction is valid.
for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
- if (isa<AllocaInst>(I) || isa<InvokeInst>(I))
- return false;
+ if (isa<AllocaInst>(I)) {
+ if (!AllowAlloca)
+ return false;
+ continue;
+ }
+
+ if (const auto *II = dyn_cast<InvokeInst>(I)) {
+ // Unwind destination (either a landingpad, catchswitch, or cleanuppad)
+ // must be a part of the subgraph which is being extracted.
+ if (auto *UBB = II->getUnwindDest())
+ if (!Result.count(UBB))
+ return false;
+ continue;
+ }
+
+ // All catch handlers of a catchswitch instruction as well as the unwind
+ // destination must be in the subgraph.
+ if (const auto *CSI = dyn_cast<CatchSwitchInst>(I)) {
+ if (auto *UBB = CSI->getUnwindDest())
+ if (!Result.count(UBB))
+ return false;
+ for (auto *HBB : CSI->handlers())
+ if (!Result.count(const_cast<BasicBlock*>(HBB)))
+ return false;
+ continue;
+ }
+
+ // Make sure that entire catch handler is within subgraph. It is sufficient
+ // to check that catch return's block is in the list.
+ if (const auto *CPI = dyn_cast<CatchPadInst>(I)) {
+ for (const auto *U : CPI->users())
+ if (const auto *CRI = dyn_cast<CatchReturnInst>(U))
+ if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
+ return false;
+ continue;
+ }
+
+ // And do similar checks for cleanup handler - the entire handler must be
+ // in subgraph which is going to be extracted. For cleanup return should
+ // additionally check that the unwind destination is also in the subgraph.
+ if (const auto *CPI = dyn_cast<CleanupPadInst>(I)) {
+ for (const auto *U : CPI->users())
+ if (const auto *CRI = dyn_cast<CleanupReturnInst>(U))
+ if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
+ return false;
+ continue;
+ }
+ if (const auto *CRI = dyn_cast<CleanupReturnInst>(I)) {
+ if (auto *UBB = CRI->getUnwindDest())
+ if (!Result.count(UBB))
+ return false;
+ continue;
+ }
+
if (const CallInst *CI = dyn_cast<CallInst>(I))
if (const Function *F = CI->getCalledFunction())
if (F->getIntrinsicID() == Intrinsic::vastart) {
@@ -129,10 +180,10 @@ bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB,
return true;
}
-/// \brief Build a set of blocks to extract if the input blocks are viable.
+/// Build a set of blocks to extract if the input blocks are viable.
static SetVector<BasicBlock *>
buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
- bool AllowVarArgs) {
+ bool AllowVarArgs, bool AllowAlloca) {
assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
SetVector<BasicBlock *> Result;
@@ -145,32 +196,42 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
if (!Result.insert(BB))
llvm_unreachable("Repeated basic blocks in extraction input");
- if (!CodeExtractor::isBlockValidForExtraction(*BB, AllowVarArgs)) {
- Result.clear();
- return Result;
- }
}
-#ifndef NDEBUG
- for (SetVector<BasicBlock *>::iterator I = std::next(Result.begin()),
- E = Result.end();
- I != E; ++I)
- for (pred_iterator PI = pred_begin(*I), PE = pred_end(*I);
- PI != PE; ++PI)
- assert(Result.count(*PI) &&
- "No blocks in this region may have entries from outside the region"
- " except for the first block!");
-#endif
+ for (auto *BB : Result) {
+ if (!isBlockValidForExtraction(*BB, Result, AllowVarArgs, AllowAlloca))
+ return {};
+
+ // Make sure that the first block is not a landing pad.
+ if (BB == Result.front()) {
+ if (BB->isEHPad()) {
+ LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n");
+ return {};
+ }
+ continue;
+ }
+
+ // All blocks other than the first must not have predecessors outside of
+ // the subgraph which is being extracted.
+ for (auto *PBB : predecessors(BB))
+ if (!Result.count(PBB)) {
+ LLVM_DEBUG(
+ dbgs() << "No blocks in this region may have entries from "
+ "outside the region except for the first block!\n");
+ return {};
+ }
+ }
return Result;
}
CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
bool AggregateArgs, BlockFrequencyInfo *BFI,
- BranchProbabilityInfo *BPI, bool AllowVarArgs)
+ BranchProbabilityInfo *BPI, bool AllowVarArgs,
+ bool AllowAlloca)
: DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
BPI(BPI), AllowVarArgs(AllowVarArgs),
- Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs)) {}
+ Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)) {}
CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
BlockFrequencyInfo *BFI,
@@ -178,7 +239,8 @@ CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
: DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
BPI(BPI), AllowVarArgs(false),
Blocks(buildExtractionBlockSet(L.getBlocks(), &DT,
- /* AllowVarArgs */ false)) {}
+ /* AllowVarArgs */ false,
+ /* AllowAlloca */ false)) {}
/// definedInRegion - Return true if the specified value is defined in the
/// extracted region.
@@ -562,8 +624,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
BasicBlock *newHeader,
Function *oldFunction,
Module *M) {
- DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
- DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
+ LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
+ LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
// This function returns unsigned, outputs will go back by reference.
switch (NumExitBlocks) {
@@ -577,20 +639,20 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
// Add the types of the input values to the function's argument list
for (Value *value : inputs) {
- DEBUG(dbgs() << "value used in func: " << *value << "\n");
+ LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
paramTy.push_back(value->getType());
}
// Add the types of the output values to the function's argument list.
for (Value *output : outputs) {
- DEBUG(dbgs() << "instr used in func: " << *output << "\n");
+ LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
if (AggregateArgs)
paramTy.push_back(output->getType());
else
paramTy.push_back(PointerType::getUnqual(output->getType()));
}
- DEBUG({
+ LLVM_DEBUG({
dbgs() << "Function type: " << *RetTy << " f(";
for (Type *i : paramTy)
dbgs() << *i << ", ";
@@ -620,16 +682,89 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
if (oldFunction->hasUWTable())
newFunction->setHasUWTable();
- // Inherit all of the target dependent attributes.
+ // Inherit all of the target dependent attributes and white-listed
+ // target independent attributes.
// (e.g. If the extracted region contains a call to an x86.sse
// instruction we need to make sure that the extracted region has the
// "target-features" attribute allowing it to be lowered.
// FIXME: This should be changed to check to see if a specific
// attribute can not be inherited.
- AttrBuilder AB(oldFunction->getAttributes().getFnAttributes());
- for (const auto &Attr : AB.td_attrs())
- newFunction->addFnAttr(Attr.first, Attr.second);
+ for (const auto &Attr : oldFunction->getAttributes().getFnAttributes()) {
+ if (Attr.isStringAttribute()) {
+ if (Attr.getKindAsString() == "thunk")
+ continue;
+ } else
+ switch (Attr.getKindAsEnum()) {
+ // Those attributes cannot be propagated safely. Explicitly list them
+ // here so we get a warning if new attributes are added. This list also
+ // includes non-function attributes.
+ case Attribute::Alignment:
+ case Attribute::AllocSize:
+ case Attribute::ArgMemOnly:
+ case Attribute::Builtin:
+ case Attribute::ByVal:
+ case Attribute::Convergent:
+ case Attribute::Dereferenceable:
+ case Attribute::DereferenceableOrNull:
+ case Attribute::InAlloca:
+ case Attribute::InReg:
+ case Attribute::InaccessibleMemOnly:
+ case Attribute::InaccessibleMemOrArgMemOnly:
+ case Attribute::JumpTable:
+ case Attribute::Naked:
+ case Attribute::Nest:
+ case Attribute::NoAlias:
+ case Attribute::NoBuiltin:
+ case Attribute::NoCapture:
+ case Attribute::NoReturn:
+ case Attribute::None:
+ case Attribute::NonNull:
+ case Attribute::ReadNone:
+ case Attribute::ReadOnly:
+ case Attribute::Returned:
+ case Attribute::ReturnsTwice:
+ case Attribute::SExt:
+ case Attribute::Speculatable:
+ case Attribute::StackAlignment:
+ case Attribute::StructRet:
+ case Attribute::SwiftError:
+ case Attribute::SwiftSelf:
+ case Attribute::WriteOnly:
+ case Attribute::ZExt:
+ case Attribute::EndAttrKinds:
+ continue;
+ // Those attributes should be safe to propagate to the extracted function.
+ case Attribute::AlwaysInline:
+ case Attribute::Cold:
+ case Attribute::NoRecurse:
+ case Attribute::InlineHint:
+ case Attribute::MinSize:
+ case Attribute::NoDuplicate:
+ case Attribute::NoImplicitFloat:
+ case Attribute::NoInline:
+ case Attribute::NonLazyBind:
+ case Attribute::NoRedZone:
+ case Attribute::NoUnwind:
+ case Attribute::OptForFuzzing:
+ case Attribute::OptimizeNone:
+ case Attribute::OptimizeForSize:
+ case Attribute::SafeStack:
+ case Attribute::ShadowCallStack:
+ case Attribute::SanitizeAddress:
+ case Attribute::SanitizeMemory:
+ case Attribute::SanitizeThread:
+ case Attribute::SanitizeHWAddress:
+ case Attribute::StackProtect:
+ case Attribute::StackProtectReq:
+ case Attribute::StackProtectStrong:
+ case Attribute::StrictFP:
+ case Attribute::UWTable:
+ case Attribute::NoCfCheck:
+ break;
+ }
+ newFunction->addFnAttr(Attr);
+ }
newFunction->getBasicBlockList().push_back(newRootNode);
// Create an iterator to name all of the arguments we inserted.
@@ -1093,10 +1228,10 @@ Function *CodeExtractor::extractCodeRegion() {
// Update the entry count of the function.
if (BFI) {
- Optional<uint64_t> EntryCount =
- BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
- if (EntryCount.hasValue())
- newFunction->setEntryCount(EntryCount.getValue());
+ auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
+ if (Count.hasValue())
+ newFunction->setEntryCount(
+ ProfileCount(Count.getValue(), Function::PCT_Real)); // FIXME
BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
}
@@ -1104,6 +1239,10 @@ Function *CodeExtractor::extractCodeRegion() {
moveCodeToFunction(newFunction);
+ // Propagate personality info to the new function if there is one.
+ if (oldFunction->hasPersonalityFn())
+ newFunction->setPersonalityFn(oldFunction->getPersonalityFn());
+
// Update the branch weights for the exit block.
if (BFI && NumExitBlocks > 1)
calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
@@ -1139,7 +1278,7 @@ Function *CodeExtractor::extractCodeRegion() {
}
}
- DEBUG(if (verifyFunction(*newFunction))
- report_fatal_error("verifyFunction failed!"));
+ LLVM_DEBUG(if (verifyFunction(*newFunction))
+ report_fatal_error("verifyFunction failed!"));
return newFunction;
}