diff options
Diffstat (limited to 'lib/Transforms/Utils/CodeExtractor.cpp')
-rw-r--r-- | lib/Transforms/Utils/CodeExtractor.cpp | 496 |
1 files changed, 298 insertions, 198 deletions
diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index 25d4ae583ecc..fa6d3f8ae873 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -1,9 +1,8 @@ //===- CodeExtractor.cpp - Pull code region into a new function -----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -21,6 +20,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" #include "llvm/Analysis/BranchProbabilityInfo.h" @@ -44,6 +44,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" @@ -67,6 +68,7 @@ #include <vector> using namespace llvm; +using namespace llvm::PatternMatch; using ProfileCount = Function::ProfileCount; #define DEBUG_TYPE "code-extractor" @@ -207,6 +209,9 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, llvm_unreachable("Repeated basic blocks in extraction input"); } + LLVM_DEBUG(dbgs() << "Region front block: " << Result.front()->getName() + << '\n'); + for (auto *BB : Result) { if (!isBlockValidForExtraction(*BB, Result, AllowVarArgs, AllowAlloca)) return {}; @@ -224,9 +229,11 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, // the subgraph which is being extracted. for (auto *PBB : predecessors(BB)) if (!Result.count(PBB)) { - LLVM_DEBUG( - dbgs() << "No blocks in this region may have entries from " - "outside the region except for the first block!\n"); + LLVM_DEBUG(dbgs() << "No blocks in this region may have entries from " + "outside the region except for the first block!\n" + << "Problematic source BB: " << BB->getName() << "\n" + << "Problematic destination BB: " << PBB->getName() + << "\n"); return {}; } } @@ -236,18 +243,20 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI, bool AllowVarArgs, - bool AllowAlloca, std::string Suffix) + BranchProbabilityInfo *BPI, AssumptionCache *AC, + bool AllowVarArgs, bool AllowAlloca, + std::string Suffix) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), AllowVarArgs(AllowVarArgs), + BPI(BPI), AC(AC), AllowVarArgs(AllowVarArgs), Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)), Suffix(Suffix) {} CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI, std::string Suffix) + BranchProbabilityInfo *BPI, AssumptionCache *AC, + std::string Suffix) : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), AllowVarArgs(false), + BPI(BPI), AC(AC), AllowVarArgs(false), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT, /* AllowVarArgs */ false, /* AllowAlloca */ false)), @@ -325,7 +334,7 @@ bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers( if (dyn_cast<Constant>(MemAddr)) break; Value *Base = MemAddr->stripInBoundsConstantOffsets(); - if (!dyn_cast<AllocaInst>(Base) || Base == AI) + if (!isa<AllocaInst>(Base) || Base == AI) return false; break; } @@ -401,11 +410,74 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) { return CommonExitBlock; } +// Find the pair of life time markers for address 'Addr' that are either +// defined inside the outline region or can legally be shrinkwrapped into the +// outline region. If there are not other untracked uses of the address, return +// the pair of markers if found; otherwise return a pair of nullptr. +CodeExtractor::LifetimeMarkerInfo +CodeExtractor::getLifetimeMarkers(Instruction *Addr, + BasicBlock *ExitBlock) const { + LifetimeMarkerInfo Info; + + for (User *U : Addr->users()) { + IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U); + if (IntrInst) { + if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) { + // Do not handle the case where Addr has multiple start markers. + if (Info.LifeStart) + return {}; + Info.LifeStart = IntrInst; + } + if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) { + if (Info.LifeEnd) + return {}; + Info.LifeEnd = IntrInst; + } + continue; + } + // Find untracked uses of the address, bail. + if (!definedInRegion(Blocks, U)) + return {}; + } + + if (!Info.LifeStart || !Info.LifeEnd) + return {}; + + Info.SinkLifeStart = !definedInRegion(Blocks, Info.LifeStart); + Info.HoistLifeEnd = !definedInRegion(Blocks, Info.LifeEnd); + // Do legality check. + if ((Info.SinkLifeStart || Info.HoistLifeEnd) && + !isLegalToShrinkwrapLifetimeMarkers(Addr)) + return {}; + + // Check to see if we have a place to do hoisting, if not, bail. + if (Info.HoistLifeEnd && !ExitBlock) + return {}; + + return Info; +} + void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands, BasicBlock *&ExitBlock) const { Function *Func = (*Blocks.begin())->getParent(); ExitBlock = getCommonExitBlock(Blocks); + auto moveOrIgnoreLifetimeMarkers = + [&](const LifetimeMarkerInfo &LMI) -> bool { + if (!LMI.LifeStart) + return false; + if (LMI.SinkLifeStart) { + LLVM_DEBUG(dbgs() << "Sinking lifetime.start: " << *LMI.LifeStart + << "\n"); + SinkCands.insert(LMI.LifeStart); + } + if (LMI.HoistLifeEnd) { + LLVM_DEBUG(dbgs() << "Hoisting lifetime.end: " << *LMI.LifeEnd << "\n"); + HoistCands.insert(LMI.LifeEnd); + } + return true; + }; + for (BasicBlock &BB : *Func) { if (Blocks.count(&BB)) continue; @@ -414,95 +486,52 @@ void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands, if (!AI) continue; - // Find the pair of life time markers for address 'Addr' that are either - // defined inside the outline region or can legally be shrinkwrapped into - // the outline region. If there are not other untracked uses of the - // address, return the pair of markers if found; otherwise return a pair - // of nullptr. - auto GetLifeTimeMarkers = - [&](Instruction *Addr, bool &SinkLifeStart, - bool &HoistLifeEnd) -> std::pair<Instruction *, Instruction *> { - Instruction *LifeStart = nullptr, *LifeEnd = nullptr; - - for (User *U : Addr->users()) { - IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U); - if (IntrInst) { - if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) { - // Do not handle the case where AI has multiple start markers. - if (LifeStart) - return std::make_pair<Instruction *>(nullptr, nullptr); - LifeStart = IntrInst; - } - if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) { - if (LifeEnd) - return std::make_pair<Instruction *>(nullptr, nullptr); - LifeEnd = IntrInst; - } - continue; - } - // Find untracked uses of the address, bail. - if (!definedInRegion(Blocks, U)) - return std::make_pair<Instruction *>(nullptr, nullptr); - } - - if (!LifeStart || !LifeEnd) - return std::make_pair<Instruction *>(nullptr, nullptr); - - SinkLifeStart = !definedInRegion(Blocks, LifeStart); - HoistLifeEnd = !definedInRegion(Blocks, LifeEnd); - // Do legality Check. - if ((SinkLifeStart || HoistLifeEnd) && - !isLegalToShrinkwrapLifetimeMarkers(Addr)) - return std::make_pair<Instruction *>(nullptr, nullptr); - - // Check to see if we have a place to do hoisting, if not, bail. - if (HoistLifeEnd && !ExitBlock) - return std::make_pair<Instruction *>(nullptr, nullptr); - - return std::make_pair(LifeStart, LifeEnd); - }; - - bool SinkLifeStart = false, HoistLifeEnd = false; - auto Markers = GetLifeTimeMarkers(AI, SinkLifeStart, HoistLifeEnd); - - if (Markers.first) { - if (SinkLifeStart) - SinkCands.insert(Markers.first); + LifetimeMarkerInfo MarkerInfo = getLifetimeMarkers(AI, ExitBlock); + bool Moved = moveOrIgnoreLifetimeMarkers(MarkerInfo); + if (Moved) { + LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI << "\n"); SinkCands.insert(AI); - if (HoistLifeEnd) - HoistCands.insert(Markers.second); continue; } - // Follow the bitcast. - Instruction *MarkerAddr = nullptr; + // Follow any bitcasts. + SmallVector<Instruction *, 2> Bitcasts; + SmallVector<LifetimeMarkerInfo, 2> BitcastLifetimeInfo; for (User *U : AI->users()) { if (U->stripInBoundsConstantOffsets() == AI) { - SinkLifeStart = false; - HoistLifeEnd = false; Instruction *Bitcast = cast<Instruction>(U); - Markers = GetLifeTimeMarkers(Bitcast, SinkLifeStart, HoistLifeEnd); - if (Markers.first) { - MarkerAddr = Bitcast; + LifetimeMarkerInfo LMI = getLifetimeMarkers(Bitcast, ExitBlock); + if (LMI.LifeStart) { + Bitcasts.push_back(Bitcast); + BitcastLifetimeInfo.push_back(LMI); continue; } } // Found unknown use of AI. if (!definedInRegion(Blocks, U)) { - MarkerAddr = nullptr; + Bitcasts.clear(); break; } } - if (MarkerAddr) { - if (SinkLifeStart) - SinkCands.insert(Markers.first); - if (!definedInRegion(Blocks, MarkerAddr)) - SinkCands.insert(MarkerAddr); - SinkCands.insert(AI); - if (HoistLifeEnd) - HoistCands.insert(Markers.second); + // Either no bitcasts reference the alloca or there are unknown uses. + if (Bitcasts.empty()) + continue; + + LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI << "\n"); + SinkCands.insert(AI); + for (unsigned I = 0, E = Bitcasts.size(); I != E; ++I) { + Instruction *BitcastAddr = Bitcasts[I]; + const LifetimeMarkerInfo &LMI = BitcastLifetimeInfo[I]; + assert(LMI.LifeStart && + "Unsafe to sink bitcast without lifetime markers"); + moveOrIgnoreLifetimeMarkers(LMI); + if (!definedInRegion(Blocks, BitcastAddr)) { + LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr + << "\n"); + SinkCands.insert(BitcastAddr); + } } } } @@ -780,6 +809,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::NoBuiltin: case Attribute::NoCapture: case Attribute::NoReturn: + case Attribute::NoSync: case Attribute::None: case Attribute::NonNull: case Attribute::ReadNone: @@ -792,8 +822,10 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::StructRet: case Attribute::SwiftError: case Attribute::SwiftSelf: + case Attribute::WillReturn: case Attribute::WriteOnly: case Attribute::ZExt: + case Attribute::ImmArg: case Attribute::EndAttrKinds: continue; // Those attributes should be safe to propagate to the extracted function. @@ -803,6 +835,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::InlineHint: case Attribute::MinSize: case Attribute::NoDuplicate: + case Attribute::NoFree: case Attribute::NoImplicitFloat: case Attribute::NoInline: case Attribute::NonLazyBind: @@ -817,6 +850,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::SanitizeMemory: case Attribute::SanitizeThread: case Attribute::SanitizeHWAddress: + case Attribute::SanitizeMemTag: case Attribute::SpeculativeLoadHardening: case Attribute::StackProtect: case Attribute::StackProtectReq: @@ -845,7 +879,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, Instruction *TI = newFunction->begin()->getTerminator(); GetElementPtrInst *GEP = GetElementPtrInst::Create( StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI); - RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI); + RewriteVal = new LoadInst(StructTy->getElementType(i), GEP, + "loadgep_" + inputs[i]->getName(), TI); } else RewriteVal = &*AI++; @@ -880,6 +915,88 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, return newFunction; } +/// Erase lifetime.start markers which reference inputs to the extraction +/// region, and insert the referenced memory into \p LifetimesStart. +/// +/// The extraction region is defined by a set of blocks (\p Blocks), and a set +/// of allocas which will be moved from the caller function into the extracted +/// function (\p SunkAllocas). +static void eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks, + const SetVector<Value *> &SunkAllocas, + SetVector<Value *> &LifetimesStart) { + for (BasicBlock *BB : Blocks) { + for (auto It = BB->begin(), End = BB->end(); It != End;) { + auto *II = dyn_cast<IntrinsicInst>(&*It); + ++It; + if (!II || !II->isLifetimeStartOrEnd()) + continue; + + // Get the memory operand of the lifetime marker. If the underlying + // object is a sunk alloca, or is otherwise defined in the extraction + // region, the lifetime marker must not be erased. + Value *Mem = II->getOperand(1)->stripInBoundsOffsets(); + if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem)) + continue; + + if (II->getIntrinsicID() == Intrinsic::lifetime_start) + LifetimesStart.insert(Mem); + II->eraseFromParent(); + } + } +} + +/// Insert lifetime start/end markers surrounding the call to the new function +/// for objects defined in the caller. +static void insertLifetimeMarkersSurroundingCall( + Module *M, ArrayRef<Value *> LifetimesStart, ArrayRef<Value *> LifetimesEnd, + CallInst *TheCall) { + LLVMContext &Ctx = M->getContext(); + auto Int8PtrTy = Type::getInt8PtrTy(Ctx); + auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1); + Instruction *Term = TheCall->getParent()->getTerminator(); + + // The memory argument to a lifetime marker must be a i8*. Cache any bitcasts + // needed to satisfy this requirement so they may be reused. + DenseMap<Value *, Value *> Bitcasts; + + // Emit lifetime markers for the pointers given in \p Objects. Insert the + // markers before the call if \p InsertBefore, and after the call otherwise. + auto insertMarkers = [&](Function *MarkerFunc, ArrayRef<Value *> Objects, + bool InsertBefore) { + for (Value *Mem : Objects) { + assert((!isa<Instruction>(Mem) || cast<Instruction>(Mem)->getFunction() == + TheCall->getFunction()) && + "Input memory not defined in original function"); + Value *&MemAsI8Ptr = Bitcasts[Mem]; + if (!MemAsI8Ptr) { + if (Mem->getType() == Int8PtrTy) + MemAsI8Ptr = Mem; + else + MemAsI8Ptr = + CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall); + } + + auto Marker = CallInst::Create(MarkerFunc, {NegativeOne, MemAsI8Ptr}); + if (InsertBefore) + Marker->insertBefore(TheCall); + else + Marker->insertBefore(Term); + } + }; + + if (!LifetimesStart.empty()) { + auto StartFn = llvm::Intrinsic::getDeclaration( + M, llvm::Intrinsic::lifetime_start, Int8PtrTy); + insertMarkers(StartFn, LifetimesStart, /*InsertBefore=*/true); + } + + if (!LifetimesEnd.empty()) { + auto EndFn = llvm::Intrinsic::getDeclaration( + M, llvm::Intrinsic::lifetime_end, Int8PtrTy); + insertMarkers(EndFn, LifetimesEnd, /*InsertBefore=*/false); + } +} + /// emitCallAndSwitchStatement - This method sets up the caller side by adding /// the call instruction, splitting any PHI nodes in the header block as /// necessary. @@ -897,11 +1014,18 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, CallInst *call = nullptr; // Add inputs as params, or to be filled into the struct - for (Value *input : inputs) + unsigned ArgNo = 0; + SmallVector<unsigned, 1> SwiftErrorArgs; + for (Value *input : inputs) { if (AggregateArgs) StructValues.push_back(input); - else + else { params.push_back(input); + if (input->isSwiftError()) + SwiftErrorArgs.push_back(ArgNo); + } + ++ArgNo; + } // Create allocas for the outputs for (Value *output : outputs) { @@ -957,13 +1081,18 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, } codeReplacer->getInstList().push_back(call); + // Set swifterror parameter attributes. + for (unsigned SwiftErrArgNo : SwiftErrorArgs) { + call->addParamAttr(SwiftErrArgNo, Attribute::SwiftError); + newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError); + } + Function::arg_iterator OutputArgBegin = newFunction->arg_begin(); unsigned FirstOut = inputs.size(); if (!AggregateArgs) std::advance(OutputArgBegin, inputs.size()); // Reload the outputs passed in by reference. - Function::arg_iterator OAI = OutputArgBegin; for (unsigned i = 0, e = outputs.size(); i != e; ++i) { Value *Output = nullptr; if (AggregateArgs) { @@ -977,7 +1106,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, } else { Output = ReloadOutputs[i]; } - LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload"); + LoadInst *load = new LoadInst(outputs[i]->getType(), Output, + outputs[i]->getName() + ".reload"); Reloads.push_back(load); codeReplacer->getInstList().push_back(load); std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end()); @@ -986,40 +1116,6 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, if (!Blocks.count(inst->getParent())) inst->replaceUsesOfWith(outputs[i], load); } - - // Store to argument right after the definition of output value. - auto *OutI = dyn_cast<Instruction>(outputs[i]); - if (!OutI) - continue; - - // Find proper insertion point. - BasicBlock::iterator InsertPt; - // In case OutI is an invoke, we insert the store at the beginning in the - // 'normal destination' BB. Otherwise we insert the store right after OutI. - if (auto *InvokeI = dyn_cast<InvokeInst>(OutI)) - InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt(); - else if (auto *Phi = dyn_cast<PHINode>(OutI)) - InsertPt = Phi->getParent()->getFirstInsertionPt(); - else - InsertPt = std::next(OutI->getIterator()); - - assert(OAI != newFunction->arg_end() && - "Number of output arguments should match " - "the amount of defined values"); - if (AggregateArgs) { - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), &*InsertPt); - new StoreInst(outputs[i], GEP, &*InsertPt); - // Since there should be only one struct argument aggregating - // all the output values, we shouldn't increment OAI, which always - // points to the struct argument, in this case. - } else { - new StoreInst(outputs[i], &*OAI, &*InsertPt); - ++OAI; - } } // Now we can emit a switch statement using the call as a value. @@ -1075,6 +1171,50 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, } } + // Store the arguments right after the definition of output value. + // This should be proceeded after creating exit stubs to be ensure that invoke + // result restore will be placed in the outlined function. + Function::arg_iterator OAI = OutputArgBegin; + for (unsigned i = 0, e = outputs.size(); i != e; ++i) { + auto *OutI = dyn_cast<Instruction>(outputs[i]); + if (!OutI) + continue; + + // Find proper insertion point. + BasicBlock::iterator InsertPt; + // In case OutI is an invoke, we insert the store at the beginning in the + // 'normal destination' BB. Otherwise we insert the store right after OutI. + if (auto *InvokeI = dyn_cast<InvokeInst>(OutI)) + InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt(); + else if (auto *Phi = dyn_cast<PHINode>(OutI)) + InsertPt = Phi->getParent()->getFirstInsertionPt(); + else + InsertPt = std::next(OutI->getIterator()); + + Instruction *InsertBefore = &*InsertPt; + assert((InsertBefore->getFunction() == newFunction || + Blocks.count(InsertBefore->getParent())) && + "InsertPt should be in new function"); + assert(OAI != newFunction->arg_end() && + "Number of output arguments should match " + "the amount of defined values"); + if (AggregateArgs) { + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), + InsertBefore); + new StoreInst(outputs[i], GEP, InsertBefore); + // Since there should be only one struct argument aggregating + // all the output values, we shouldn't increment OAI, which always + // points to the struct argument, in this case. + } else { + new StoreInst(outputs[i], &*OAI, InsertBefore); + ++OAI; + } + } + // Now that we've done the deed, simplify the switch instruction. Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType(); switch (NumExitBlocks) { @@ -1119,6 +1259,10 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, break; } + // Insert lifetime markers around the reloads of any output values. The + // allocas output values are stored in are only in-use in the codeRepl block. + insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call); + return call; } @@ -1133,6 +1277,13 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) { // Insert this basic block into the new function newBlocks.push_back(Block); + + // Remove @llvm.assume calls that were moved to the new function from the + // old function's assumption cache. + if (AC) + for (auto &I : *Block) + if (match(&I, m_Intrinsic<Intrinsic::assume>())) + AC->unregisterAssumption(cast<CallInst>(&I)); } } @@ -1181,71 +1332,6 @@ void CodeExtractor::calculateNewCallTerminatorWeights( MDBuilder(TI->getContext()).createBranchWeights(BranchWeights)); } -/// Scan the extraction region for lifetime markers which reference inputs. -/// Erase these markers. Return the inputs which were referenced. -/// -/// The extraction region is defined by a set of blocks (\p Blocks), and a set -/// of allocas which will be moved from the caller function into the extracted -/// function (\p SunkAllocas). -static SetVector<Value *> -eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks, - const SetVector<Value *> &SunkAllocas) { - SetVector<Value *> InputObjectsWithLifetime; - for (BasicBlock *BB : Blocks) { - for (auto It = BB->begin(), End = BB->end(); It != End;) { - auto *II = dyn_cast<IntrinsicInst>(&*It); - ++It; - if (!II || !II->isLifetimeStartOrEnd()) - continue; - - // Get the memory operand of the lifetime marker. If the underlying - // object is a sunk alloca, or is otherwise defined in the extraction - // region, the lifetime marker must not be erased. - Value *Mem = II->getOperand(1)->stripInBoundsOffsets(); - if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem)) - continue; - - InputObjectsWithLifetime.insert(Mem); - II->eraseFromParent(); - } - } - return InputObjectsWithLifetime; -} - -/// Insert lifetime start/end markers surrounding the call to the new function -/// for objects defined in the caller. -static void insertLifetimeMarkersSurroundingCall( - Module *M, const SetVector<Value *> &InputObjectsWithLifetime, - CallInst *TheCall) { - if (InputObjectsWithLifetime.empty()) - return; - - LLVMContext &Ctx = M->getContext(); - auto Int8PtrTy = Type::getInt8PtrTy(Ctx); - auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1); - auto LifetimeStartFn = llvm::Intrinsic::getDeclaration( - M, llvm::Intrinsic::lifetime_start, Int8PtrTy); - auto LifetimeEndFn = llvm::Intrinsic::getDeclaration( - M, llvm::Intrinsic::lifetime_end, Int8PtrTy); - for (Value *Mem : InputObjectsWithLifetime) { - assert((!isa<Instruction>(Mem) || - cast<Instruction>(Mem)->getFunction() == TheCall->getFunction()) && - "Input memory not defined in original function"); - Value *MemAsI8Ptr = nullptr; - if (Mem->getType() == Int8PtrTy) - MemAsI8Ptr = Mem; - else - MemAsI8Ptr = - CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall); - - auto StartMarker = - CallInst::Create(LifetimeStartFn, {NegativeOne, MemAsI8Ptr}); - StartMarker->insertBefore(TheCall); - auto EndMarker = CallInst::Create(LifetimeEndFn, {NegativeOne, MemAsI8Ptr}); - EndMarker->insertAfter(TheCall); - } -} - Function *CodeExtractor::extractCodeRegion() { if (!isEligible()) return nullptr; @@ -1348,10 +1434,24 @@ Function *CodeExtractor::extractCodeRegion() { // Find inputs to, outputs from the code region. findInputsOutputs(inputs, outputs, SinkingCands); - // Now sink all instructions which only have non-phi uses inside the region - for (auto *II : SinkingCands) - cast<Instruction>(II)->moveBefore(*newFuncRoot, - newFuncRoot->getFirstInsertionPt()); + // Now sink all instructions which only have non-phi uses inside the region. + // Group the allocas at the start of the block, so that any bitcast uses of + // the allocas are well-defined. + AllocaInst *FirstSunkAlloca = nullptr; + for (auto *II : SinkingCands) { + if (auto *AI = dyn_cast<AllocaInst>(II)) { + AI->moveBefore(*newFuncRoot, newFuncRoot->getFirstInsertionPt()); + if (!FirstSunkAlloca) + FirstSunkAlloca = AI; + } + } + assert((SinkingCands.empty() || FirstSunkAlloca) && + "Did not expect a sink candidate without any allocas"); + for (auto *II : SinkingCands) { + if (!isa<AllocaInst>(II)) { + cast<Instruction>(II)->moveAfter(FirstSunkAlloca); + } + } if (!HoistingCands.empty()) { auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit); @@ -1361,11 +1461,11 @@ Function *CodeExtractor::extractCodeRegion() { } // Collect objects which are inputs to the extraction region and also - // referenced by lifetime start/end markers within it. The effects of these + // referenced by lifetime start markers within it. The effects of these // markers must be replicated in the calling function to prevent the stack // coloring pass from merging slots which store input objects. - ValueSet InputObjectsWithLifetime = - eraseLifetimeMarkersOnInputs(Blocks, SinkingCands); + ValueSet LifetimesStart; + eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart); // Construct new function based on inputs/outputs & add allocas for all defs. Function *newFunction = @@ -1388,8 +1488,8 @@ Function *CodeExtractor::extractCodeRegion() { // Replicate the effects of any lifetime start/end markers which referenced // input objects in the extraction region by placing markers around the call. - insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), - InputObjectsWithLifetime, TheCall); + insertLifetimeMarkersSurroundingCall( + oldFunction->getParent(), LifetimesStart.getArrayRef(), {}, TheCall); // Propagate personality info to the new function if there is one. if (oldFunction->hasPersonalityFn()) |