diff options
Diffstat (limited to 'contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp')
| -rw-r--r-- | contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp | 356 |
1 files changed, 291 insertions, 65 deletions
diff --git a/contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp index c514c9c9cd4a..5d57ed9718fb 100644 --- a/contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -27,6 +27,7 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" @@ -73,24 +74,26 @@ bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) { } /// \brief Build a set of blocks to extract if the input blocks are viable. -template <typename IteratorT> -static SetVector<BasicBlock *> buildExtractionBlockSet(IteratorT BBBegin, - IteratorT BBEnd) { +static SetVector<BasicBlock *> +buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT) { + assert(!BBs.empty() && "The set of blocks to extract must be non-empty"); SetVector<BasicBlock *> Result; - assert(BBBegin != BBEnd); - // Loop over the blocks, adding them to our set-vector, and aborting with an // empty set if we encounter invalid blocks. - do { - if (!Result.insert(*BBBegin)) - llvm_unreachable("Repeated basic blocks in extraction input"); + for (BasicBlock *BB : BBs) { + + // If this block is dead, don't process it. + if (DT && !DT->isReachableFromEntry(BB)) + continue; - if (!CodeExtractor::isBlockValidForExtraction(**BBBegin)) { + if (!Result.insert(BB)) + llvm_unreachable("Repeated basic blocks in extraction input"); + if (!CodeExtractor::isBlockValidForExtraction(*BB)) { Result.clear(); return Result; } - } while (++BBBegin != BBEnd); + } #ifndef NDEBUG for (SetVector<BasicBlock *>::iterator I = std::next(Result.begin()), @@ -106,49 +109,19 @@ static SetVector<BasicBlock *> buildExtractionBlockSet(IteratorT BBBegin, return Result; } -/// \brief Helper to call buildExtractionBlockSet with an ArrayRef. -static SetVector<BasicBlock *> -buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs) { - return buildExtractionBlockSet(BBs.begin(), BBs.end()); -} - -/// \brief Helper to call buildExtractionBlockSet with a RegionNode. -static SetVector<BasicBlock *> -buildExtractionBlockSet(const RegionNode &RN) { - if (!RN.isSubRegion()) - // Just a single BasicBlock. - return buildExtractionBlockSet(RN.getNodeAs<BasicBlock>()); - - const Region &R = *RN.getNodeAs<Region>(); - - return buildExtractionBlockSet(R.block_begin(), R.block_end()); -} - -CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs, - BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI) - : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} - CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} + BPI(BPI), Blocks(buildExtractionBlockSet(BBs, DT)), NumExitBlocks(~0U) {} CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI) : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())), + BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT)), NumExitBlocks(~0U) {} -CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN, - bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI) - : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {} - /// definedInRegion - Return true if the specified value is defined in the /// extracted region. static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) { @@ -169,16 +142,255 @@ static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) { return false; } -void CodeExtractor::findInputsOutputs(ValueSet &Inputs, - ValueSet &Outputs) const { +static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) { + BasicBlock *CommonExitBlock = nullptr; + auto hasNonCommonExitSucc = [&](BasicBlock *Block) { + for (auto *Succ : successors(Block)) { + // Internal edges, ok. + if (Blocks.count(Succ)) + continue; + if (!CommonExitBlock) { + CommonExitBlock = Succ; + continue; + } + if (CommonExitBlock == Succ) + continue; + + return true; + } + return false; + }; + + if (any_of(Blocks, hasNonCommonExitSucc)) + return nullptr; + + return CommonExitBlock; +} + +bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers( + Instruction *Addr) const { + AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets()); + Function *Func = (*Blocks.begin())->getParent(); + for (BasicBlock &BB : *Func) { + if (Blocks.count(&BB)) + continue; + for (Instruction &II : BB) { + + if (isa<DbgInfoIntrinsic>(II)) + continue; + + unsigned Opcode = II.getOpcode(); + Value *MemAddr = nullptr; + switch (Opcode) { + case Instruction::Store: + case Instruction::Load: { + if (Opcode == Instruction::Store) { + StoreInst *SI = cast<StoreInst>(&II); + MemAddr = SI->getPointerOperand(); + } else { + LoadInst *LI = cast<LoadInst>(&II); + MemAddr = LI->getPointerOperand(); + } + // Global variable can not be aliased with locals. + if (dyn_cast<Constant>(MemAddr)) + break; + Value *Base = MemAddr->stripInBoundsConstantOffsets(); + if (!dyn_cast<AllocaInst>(Base) || Base == AI) + return false; + break; + } + default: { + IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II); + if (IntrInst) { + if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start || + IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) + break; + return false; + } + // Treat all the other cases conservatively if it has side effects. + if (II.mayHaveSideEffects()) + return false; + } + } + } + } + + return true; +} + +BasicBlock * +CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) { + BasicBlock *SinglePredFromOutlineRegion = nullptr; + assert(!Blocks.count(CommonExitBlock) && + "Expect a block outside the region!"); + for (auto *Pred : predecessors(CommonExitBlock)) { + if (!Blocks.count(Pred)) + continue; + if (!SinglePredFromOutlineRegion) { + SinglePredFromOutlineRegion = Pred; + } else if (SinglePredFromOutlineRegion != Pred) { + SinglePredFromOutlineRegion = nullptr; + break; + } + } + + if (SinglePredFromOutlineRegion) + return SinglePredFromOutlineRegion; + +#ifndef NDEBUG + auto getFirstPHI = [](BasicBlock *BB) { + BasicBlock::iterator I = BB->begin(); + PHINode *FirstPhi = nullptr; + while (I != BB->end()) { + PHINode *Phi = dyn_cast<PHINode>(I); + if (!Phi) + break; + if (!FirstPhi) { + FirstPhi = Phi; + break; + } + } + return FirstPhi; + }; + // If there are any phi nodes, the single pred either exists or has already + // be created before code extraction. + assert(!getFirstPHI(CommonExitBlock) && "Phi not expected"); +#endif + + BasicBlock *NewExitBlock = CommonExitBlock->splitBasicBlock( + CommonExitBlock->getFirstNonPHI()->getIterator()); + + for (auto *Pred : predecessors(CommonExitBlock)) { + if (Blocks.count(Pred)) + continue; + Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock); + } + // Now add the old exit block to the outline region. + Blocks.insert(CommonExitBlock); + return CommonExitBlock; +} + +void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands, + BasicBlock *&ExitBlock) const { + Function *Func = (*Blocks.begin())->getParent(); + ExitBlock = getCommonExitBlock(Blocks); + + for (BasicBlock &BB : *Func) { + if (Blocks.count(&BB)) + continue; + for (Instruction &II : BB) { + auto *AI = dyn_cast<AllocaInst>(&II); + if (!AI) + continue; + + // Find the pair of life time markers for address 'Addr' that are either + // defined inside the outline region or can legally be shrinkwrapped into + // the outline region. If there are not other untracked uses of the + // address, return the pair of markers if found; otherwise return a pair + // of nullptr. + auto GetLifeTimeMarkers = + [&](Instruction *Addr, bool &SinkLifeStart, + bool &HoistLifeEnd) -> std::pair<Instruction *, Instruction *> { + Instruction *LifeStart = nullptr, *LifeEnd = nullptr; + + for (User *U : Addr->users()) { + IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U); + if (IntrInst) { + if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) { + // Do not handle the case where AI has multiple start markers. + if (LifeStart) + return std::make_pair<Instruction *>(nullptr, nullptr); + LifeStart = IntrInst; + } + if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) { + if (LifeEnd) + return std::make_pair<Instruction *>(nullptr, nullptr); + LifeEnd = IntrInst; + } + continue; + } + // Find untracked uses of the address, bail. + if (!definedInRegion(Blocks, U)) + return std::make_pair<Instruction *>(nullptr, nullptr); + } + + if (!LifeStart || !LifeEnd) + return std::make_pair<Instruction *>(nullptr, nullptr); + + SinkLifeStart = !definedInRegion(Blocks, LifeStart); + HoistLifeEnd = !definedInRegion(Blocks, LifeEnd); + // Do legality Check. + if ((SinkLifeStart || HoistLifeEnd) && + !isLegalToShrinkwrapLifetimeMarkers(Addr)) + return std::make_pair<Instruction *>(nullptr, nullptr); + + // Check to see if we have a place to do hoisting, if not, bail. + if (HoistLifeEnd && !ExitBlock) + return std::make_pair<Instruction *>(nullptr, nullptr); + + return std::make_pair(LifeStart, LifeEnd); + }; + + bool SinkLifeStart = false, HoistLifeEnd = false; + auto Markers = GetLifeTimeMarkers(AI, SinkLifeStart, HoistLifeEnd); + + if (Markers.first) { + if (SinkLifeStart) + SinkCands.insert(Markers.first); + SinkCands.insert(AI); + if (HoistLifeEnd) + HoistCands.insert(Markers.second); + continue; + } + + // Follow the bitcast. + Instruction *MarkerAddr = nullptr; + for (User *U : AI->users()) { + + if (U->stripInBoundsConstantOffsets() == AI) { + SinkLifeStart = false; + HoistLifeEnd = false; + Instruction *Bitcast = cast<Instruction>(U); + Markers = GetLifeTimeMarkers(Bitcast, SinkLifeStart, HoistLifeEnd); + if (Markers.first) { + MarkerAddr = Bitcast; + continue; + } + } + + // Found unknown use of AI. + if (!definedInRegion(Blocks, U)) { + MarkerAddr = nullptr; + break; + } + } + + if (MarkerAddr) { + if (SinkLifeStart) + SinkCands.insert(Markers.first); + if (!definedInRegion(Blocks, MarkerAddr)) + SinkCands.insert(MarkerAddr); + SinkCands.insert(AI); + if (HoistLifeEnd) + HoistCands.insert(Markers.second); + } + } + } +} + +void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs, + const ValueSet &SinkCands) const { + for (BasicBlock *BB : Blocks) { // If a used value is defined outside the region, it's an input. If an // instruction is used outside the region, it's an output. for (Instruction &II : *BB) { for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE; - ++OI) - if (definedInCaller(Blocks, *OI)) - Inputs.insert(*OI); + ++OI) { + Value *V = *OI; + if (!SinkCands.count(V) && definedInCaller(Blocks, V)) + Inputs.insert(V); + } for (User *U : II.users()) if (!definedInRegion(Blocks, U)) { @@ -218,9 +430,7 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { // containing PHI nodes merging values from outside of the region, and a // second that contains all of the code for the block and merges back any // incoming values from inside of the region. - BasicBlock::iterator AfterPHIs = Header->getFirstNonPHI()->getIterator(); - BasicBlock *NewBB = Header->splitBasicBlock(AfterPHIs, - Header->getName()+".ce"); + BasicBlock *NewBB = llvm::SplitBlock(Header, Header->getFirstNonPHI(), DT); // We only want to code extract the second block now, and it becomes the new // header of the region. @@ -229,11 +439,6 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { Blocks.insert(NewBB); Header = NewBB; - // Okay, update dominator sets. The blocks that dominate the new one are the - // blocks that dominate TIBB plus the new block itself. - if (DT) - DT->splitBlock(NewBB); - // Okay, now we need to adjust the PHI nodes and any branches from within the // region to go to the new header block instead of the old header block. if (NumPredsFromRegion) { @@ -248,12 +453,14 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { // Okay, everything within the region is now branching to the right block, we // just have to update the PHI nodes now, inserting PHI nodes into NewBB. + BasicBlock::iterator AfterPHIs; for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) { PHINode *PN = cast<PHINode>(AfterPHIs); // Create a new PHI node in the new region, which has an incoming value // from OldPred of PN. PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion, PN->getName() + ".ce", &NewBB->front()); + PN->replaceAllUsesWith(NewPN); NewPN->addIncoming(PN, OldPred); // Loop over all of the incoming value in PN, moving them to NewPN if they @@ -362,9 +569,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, // "target-features" attribute allowing it to be lowered. // FIXME: This should be changed to check to see if a specific // attribute can not be inherited. - AttributeSet OldFnAttrs = oldFunction->getAttributes().getFnAttributes(); - AttrBuilder AB(OldFnAttrs, AttributeSet::FunctionIndex); - for (auto Attr : AB.td_attrs()) + AttrBuilder AB(oldFunction->getAttributes().getFnAttributes()); + for (const auto &Attr : AB.td_attrs()) newFunction->addFnAttr(Attr.first, Attr.second); newFunction->getBasicBlockList().push_back(newRootNode); @@ -440,8 +646,10 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, // Emit a call to the new function, passing in: *pointer to struct (if // aggregating parameters), or plan inputs and allocated memory for outputs std::vector<Value*> params, StructValues, ReloadOutputs, Reloads; - - LLVMContext &Context = newFunction->getContext(); + + Module *M = newFunction->getParent(); + LLVMContext &Context = M->getContext(); + const DataLayout &DL = M->getDataLayout(); // Add inputs as params, or to be filled into the struct for (Value *input : inputs) @@ -456,8 +664,9 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, StructValues.push_back(output); } else { AllocaInst *alloca = - new AllocaInst(output->getType(), nullptr, output->getName() + ".loc", - &codeReplacer->getParent()->front().front()); + new AllocaInst(output->getType(), DL.getAllocaAddrSpace(), + nullptr, output->getName() + ".loc", + &codeReplacer->getParent()->front().front()); ReloadOutputs.push_back(alloca); params.push_back(alloca); } @@ -473,7 +682,8 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, // Allocate a struct at the beginning of this function StructArgTy = StructType::get(newFunction->getContext(), ArgTypes); - Struct = new AllocaInst(StructArgTy, nullptr, "structArg", + Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr, + "structArg", &codeReplacer->getParent()->front().front()); params.push_back(Struct); @@ -748,7 +958,8 @@ Function *CodeExtractor::extractCodeRegion() { if (!isEligible()) return nullptr; - ValueSet inputs, outputs; + ValueSet inputs, outputs, SinkingCands, HoistingCands; + BasicBlock *CommonExit = nullptr; // Assumption: this is a single-entry code region, and the header is the first // block in the region. @@ -787,8 +998,23 @@ Function *CodeExtractor::extractCodeRegion() { "newFuncRoot"); newFuncRoot->getInstList().push_back(BranchInst::Create(header)); + findAllocas(SinkingCands, HoistingCands, CommonExit); + assert(HoistingCands.empty() || CommonExit); + // Find inputs to, outputs from the code region. - findInputsOutputs(inputs, outputs); + findInputsOutputs(inputs, outputs, SinkingCands); + + // Now sink all instructions which only have non-phi uses inside the region + for (auto *II : SinkingCands) + cast<Instruction>(II)->moveBefore(*newFuncRoot, + newFuncRoot->getFirstInsertionPt()); + + if (!HoistingCands.empty()) { + auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit); + Instruction *TI = HoistToBlock->getTerminator(); + for (auto *II : HoistingCands) + cast<Instruction>(II)->moveBefore(TI); + } // Calculate the exit blocks for the extracted region and the total exit // weights for each of those blocks. |
