summaryrefslogtreecommitdiff
path: root/lib/Transforms/Utils/CodeExtractor.cpp
diff options
context:
space:
mode:
author: Dimitry Andric <dim@FreeBSD.org> 2017-01-02 19:17:04 +0000
committer: Dimitry Andric <dim@FreeBSD.org> 2017-01-02 19:17:04 +0000
commit: b915e9e0fc85ba6f398b3fab0db6a81a8913af94 (patch)
tree: 98b8f811c7aff2547cab8642daf372d6c59502fb /lib/Transforms/Utils/CodeExtractor.cpp
parent: 6421cca32f69ac849537a3cff78c352195e99f1b (diff)
Notes
Diffstat (limited to 'lib/Transforms/Utils/CodeExtractor.cpp')
-rw-r--r-- lib/Transforms/Utils/CodeExtractor.cpp 147
1 files changed, 128 insertions, 19 deletions
diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp
index 9f2181f87cee..c514c9c9cd4a 100644
--- a/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/lib/Transforms/Utils/CodeExtractor.cpp
@@ -17,6 +17,9 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/RegionInfo.h"
#include "llvm/Analysis/RegionIterator.h"
@@ -26,9 +29,11 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
+#include "llvm/Support/BlockFrequency.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
@@ -49,7 +54,7 @@ AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
cl::desc("Aggregate arguments to code-extracted functions"));
/// \brief Test whether a block is valid for extraction.
-static bool isBlockValidForExtraction(const BasicBlock &BB) {
+bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) {
// Landing pads must be in the function where they were inserted for cleanup.
if (BB.isEHPad())
return false;
@@ -81,7 +86,7 @@ static SetVector<BasicBlock *> buildExtractionBlockSet(IteratorT BBBegin,
if (!Result.insert(*BBBegin))
llvm_unreachable("Repeated basic blocks in extraction input");
- if (!isBlockValidForExtraction(**BBBegin)) {
+ if (!CodeExtractor::isBlockValidForExtraction(**BBBegin)) {
Result.clear();
return Result;
}
@@ -119,23 +124,30 @@ buildExtractionBlockSet(const RegionNode &RN) {
return buildExtractionBlockSet(R.block_begin(), R.block_end());
}
-CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs)
- : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt),
- Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
+CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs,
+ BlockFrequencyInfo *BFI,
+ BranchProbabilityInfo *BPI)
+ : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+ BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
- bool AggregateArgs)
- : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
- Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
-
-CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs)
- : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
- Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {}
+ bool AggregateArgs, BlockFrequencyInfo *BFI,
+ BranchProbabilityInfo *BPI)
+ : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+ BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
+
+CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
+ BlockFrequencyInfo *BFI,
+ BranchProbabilityInfo *BPI)
+ : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+ BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())),
+ NumExitBlocks(~0U) {}
CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN,
- bool AggregateArgs)
- : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
- Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
+ bool AggregateArgs, BlockFrequencyInfo *BFI,
+ BranchProbabilityInfo *BPI)
+ : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+ BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
/// definedInRegion - Return true if the specified value is defined in the
/// extracted region.
@@ -339,7 +351,22 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
// If the old function is no-throw, so is the new one.
if (oldFunction->doesNotThrow())
newFunction->setDoesNotThrow();
-
+
+ // Inherit the uwtable attribute if we need to.
+ if (oldFunction->hasUWTable())
+ newFunction->setHasUWTable();
+
+ // Inherit all of the target-dependent attributes.
+ // (e.g. If the extracted region contains a call to an x86.sse
+ // instruction we need to make sure that the extracted region has the
+ // "target-features" attribute allowing it to be lowered.)
+ // FIXME: This should be changed to check to see if a specific
+ // attribute cannot be inherited.
+ AttributeSet OldFnAttrs = oldFunction->getAttributes().getFnAttributes();
+ AttrBuilder AB(OldFnAttrs, AttributeSet::FunctionIndex);
+ for (auto Attr : AB.td_attrs())
+ newFunction->addFnAttr(Attr.first, Attr.second);
+
newFunction->getBasicBlockList().push_back(newRootNode);
// Create an iterator to name all of the arguments we inserted.
@@ -672,6 +699,51 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) {
}
}
+void CodeExtractor::calculateNewCallTerminatorWeights(
+ BasicBlock *CodeReplacer,
+ DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+ BranchProbabilityInfo *BPI) {
+ typedef BlockFrequencyInfoImplBase::Distribution Distribution;
+ typedef BlockFrequencyInfoImplBase::BlockNode BlockNode;
+
+ // Derive branch weights for CodeReplacer's terminator from ExitWeights, recording them in BPI and as MD_prof metadata.
+ TerminatorInst *TI = CodeReplacer->getTerminator();
+ SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
+
+ // Distribution whose nodes are successor indices (BlockNode(i)) standing in for the exit blocks.
+ Distribution BranchDist;
+
+ // Add each successor's exit frequency; a successor with zero frequency gets a zero edge probability instead.
+ for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
+ BlockNode ExitNode(i);
+ uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
+ if (ExitFreq != 0)
+ BranchDist.addExit(ExitNode, ExitFreq);
+ else
+ BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
+ }
+
+ // If the distribution has no total weight, return without attaching any branch-weight metadata.
+ if (BranchDist.Total == 0)
+ return;
+
+ // Normalize the distribution so that the weights can fit in unsigned.
+ BranchDist.normalize();
+
+ // Create normalized branch weights and set the metadata.
+ for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
+ const auto &Weight = BranchDist.Weights[I];
+
+ // Store the weight at its successor index and update the edge probability in BPI.
+ BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
+ BranchProbability BP(Weight.Amount, BranchDist.Total);
+ BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
+ }
+ TI->setMetadata(
+ LLVMContext::MD_prof,
+ MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
+}
+
Function *CodeExtractor::extractCodeRegion() {
if (!isEligible())
return nullptr;
@@ -682,6 +754,19 @@ Function *CodeExtractor::extractCodeRegion() {
// block in the region.
BasicBlock *header = *Blocks.begin();
+ // Calculate the entry frequency of the new function before we change the root
+ // block.
+ BlockFrequency EntryFreq;
+ if (BFI) {
+ assert(BPI && "Both BPI and BFI are required to preserve profile info");
+ for (BasicBlock *Pred : predecessors(header)) {
+ if (Blocks.count(Pred))
+ continue;
+ EntryFreq +=
+ BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
+ }
+ }
+
// If we have to split PHI nodes or the entry block, do so now.
severSplitPHINodes(header);
@@ -705,12 +790,23 @@ Function *CodeExtractor::extractCodeRegion() {
// Find inputs to, outputs from the code region.
findInputsOutputs(inputs, outputs);
+ // Calculate the exit blocks for the extracted region and the total exit
+ // weights for each of those blocks.
+ DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
SmallPtrSet<BasicBlock *, 1> ExitBlocks;
- for (BasicBlock *Block : Blocks)
+ for (BasicBlock *Block : Blocks) {
for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
- ++SI)
- if (!Blocks.count(*SI))
+ ++SI) {
+ if (!Blocks.count(*SI)) {
+ // Update the branch weight for this successor.
+ if (BFI) {
+ BlockFrequency &BF = ExitWeights[*SI];
+ BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
+ }
ExitBlocks.insert(*SI);
+ }
+ }
+ }
NumExitBlocks = ExitBlocks.size();
// Construct new function based on inputs/outputs & add allocas for all defs.
@@ -719,10 +815,23 @@ Function *CodeExtractor::extractCodeRegion() {
codeReplacer, oldFunction,
oldFunction->getParent());
+ // Update the entry count of the function.
+ if (BFI) {
+ Optional<uint64_t> EntryCount =
+ BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
+ if (EntryCount.hasValue())
+ newFunction->setEntryCount(EntryCount.getValue());
+ BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
+ }
+
emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
moveCodeToFunction(newFunction);
+ // Update the branch weights for the exit block.
+ if (BFI && NumExitBlocks > 1)
+ calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
+
// Loop over all of the PHI nodes in the header block, and change any
// references to the old incoming edge to be the new incoming edge.
for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {