aboutsummaryrefslogtreecommitdiff
path: root/lib/Transforms/Utils/CodeExtractor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/Utils/CodeExtractor.cpp')
-rw-r--r--lib/Transforms/Utils/CodeExtractor.cpp496
1 files changed, 298 insertions, 198 deletions
diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp
index 25d4ae583ecc..fa6d3f8ae873 100644
--- a/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/lib/Transforms/Utils/CodeExtractor.cpp
@@ -1,9 +1,8 @@
//===- CodeExtractor.cpp - Pull code region into a new function -----------===//
//
-// The LLVM Compiler Infrastructure
-//
-// This file is distributed under the University of Illinois Open Source
-// License. See LICENSE.TXT for details.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
@@ -21,6 +20,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
@@ -44,6 +44,7 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
@@ -67,6 +68,7 @@
#include <vector>
using namespace llvm;
+using namespace llvm::PatternMatch;
using ProfileCount = Function::ProfileCount;
#define DEBUG_TYPE "code-extractor"
@@ -207,6 +209,9 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
llvm_unreachable("Repeated basic blocks in extraction input");
}
+ LLVM_DEBUG(dbgs() << "Region front block: " << Result.front()->getName()
+ << '\n');
+
for (auto *BB : Result) {
if (!isBlockValidForExtraction(*BB, Result, AllowVarArgs, AllowAlloca))
return {};
@@ -224,9 +229,11 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
// the subgraph which is being extracted.
for (auto *PBB : predecessors(BB))
if (!Result.count(PBB)) {
- LLVM_DEBUG(
- dbgs() << "No blocks in this region may have entries from "
- "outside the region except for the first block!\n");
+ LLVM_DEBUG(dbgs() << "No blocks in this region may have entries from "
+ "outside the region except for the first block!\n"
+ << "Problematic source BB: " << BB->getName() << "\n"
+ << "Problematic destination BB: " << PBB->getName()
+ << "\n");
return {};
}
}
@@ -236,18 +243,20 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
bool AggregateArgs, BlockFrequencyInfo *BFI,
- BranchProbabilityInfo *BPI, bool AllowVarArgs,
- bool AllowAlloca, std::string Suffix)
+ BranchProbabilityInfo *BPI, AssumptionCache *AC,
+ bool AllowVarArgs, bool AllowAlloca,
+ std::string Suffix)
: DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
- BPI(BPI), AllowVarArgs(AllowVarArgs),
+ BPI(BPI), AC(AC), AllowVarArgs(AllowVarArgs),
Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
Suffix(Suffix) {}
CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
BlockFrequencyInfo *BFI,
- BranchProbabilityInfo *BPI, std::string Suffix)
+ BranchProbabilityInfo *BPI, AssumptionCache *AC,
+ std::string Suffix)
: DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
- BPI(BPI), AllowVarArgs(false),
+ BPI(BPI), AC(AC), AllowVarArgs(false),
Blocks(buildExtractionBlockSet(L.getBlocks(), &DT,
/* AllowVarArgs */ false,
/* AllowAlloca */ false)),
@@ -325,7 +334,7 @@ bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
if (dyn_cast<Constant>(MemAddr))
break;
Value *Base = MemAddr->stripInBoundsConstantOffsets();
- if (!dyn_cast<AllocaInst>(Base) || Base == AI)
+ if (!isa<AllocaInst>(Base) || Base == AI)
return false;
break;
}
@@ -401,11 +410,74 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
return CommonExitBlock;
}
+// Find the pair of life time markers for address 'Addr' that are either
+// defined inside the outline region or can legally be shrinkwrapped into the
+// outline region. If there are not other untracked uses of the address, return
+// the pair of markers if found; otherwise return a pair of nullptr.
+CodeExtractor::LifetimeMarkerInfo
+CodeExtractor::getLifetimeMarkers(Instruction *Addr,
+ BasicBlock *ExitBlock) const {
+ LifetimeMarkerInfo Info;
+
+ for (User *U : Addr->users()) {
+ IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
+ if (IntrInst) {
+ if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
+ // Do not handle the case where Addr has multiple start markers.
+ if (Info.LifeStart)
+ return {};
+ Info.LifeStart = IntrInst;
+ }
+ if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
+ if (Info.LifeEnd)
+ return {};
+ Info.LifeEnd = IntrInst;
+ }
+ continue;
+ }
+ // Find untracked uses of the address, bail.
+ if (!definedInRegion(Blocks, U))
+ return {};
+ }
+
+ if (!Info.LifeStart || !Info.LifeEnd)
+ return {};
+
+ Info.SinkLifeStart = !definedInRegion(Blocks, Info.LifeStart);
+ Info.HoistLifeEnd = !definedInRegion(Blocks, Info.LifeEnd);
+ // Do legality check.
+ if ((Info.SinkLifeStart || Info.HoistLifeEnd) &&
+ !isLegalToShrinkwrapLifetimeMarkers(Addr))
+ return {};
+
+ // Check to see if we have a place to do hoisting, if not, bail.
+ if (Info.HoistLifeEnd && !ExitBlock)
+ return {};
+
+ return Info;
+}
+
void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands,
BasicBlock *&ExitBlock) const {
Function *Func = (*Blocks.begin())->getParent();
ExitBlock = getCommonExitBlock(Blocks);
+ auto moveOrIgnoreLifetimeMarkers =
+ [&](const LifetimeMarkerInfo &LMI) -> bool {
+ if (!LMI.LifeStart)
+ return false;
+ if (LMI.SinkLifeStart) {
+ LLVM_DEBUG(dbgs() << "Sinking lifetime.start: " << *LMI.LifeStart
+ << "\n");
+ SinkCands.insert(LMI.LifeStart);
+ }
+ if (LMI.HoistLifeEnd) {
+ LLVM_DEBUG(dbgs() << "Hoisting lifetime.end: " << *LMI.LifeEnd << "\n");
+ HoistCands.insert(LMI.LifeEnd);
+ }
+ return true;
+ };
+
for (BasicBlock &BB : *Func) {
if (Blocks.count(&BB))
continue;
@@ -414,95 +486,52 @@ void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands,
if (!AI)
continue;
- // Find the pair of life time markers for address 'Addr' that are either
- // defined inside the outline region or can legally be shrinkwrapped into
- // the outline region. If there are not other untracked uses of the
- // address, return the pair of markers if found; otherwise return a pair
- // of nullptr.
- auto GetLifeTimeMarkers =
- [&](Instruction *Addr, bool &SinkLifeStart,
- bool &HoistLifeEnd) -> std::pair<Instruction *, Instruction *> {
- Instruction *LifeStart = nullptr, *LifeEnd = nullptr;
-
- for (User *U : Addr->users()) {
- IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
- if (IntrInst) {
- if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
- // Do not handle the case where AI has multiple start markers.
- if (LifeStart)
- return std::make_pair<Instruction *>(nullptr, nullptr);
- LifeStart = IntrInst;
- }
- if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
- if (LifeEnd)
- return std::make_pair<Instruction *>(nullptr, nullptr);
- LifeEnd = IntrInst;
- }
- continue;
- }
- // Find untracked uses of the address, bail.
- if (!definedInRegion(Blocks, U))
- return std::make_pair<Instruction *>(nullptr, nullptr);
- }
-
- if (!LifeStart || !LifeEnd)
- return std::make_pair<Instruction *>(nullptr, nullptr);
-
- SinkLifeStart = !definedInRegion(Blocks, LifeStart);
- HoistLifeEnd = !definedInRegion(Blocks, LifeEnd);
- // Do legality Check.
- if ((SinkLifeStart || HoistLifeEnd) &&
- !isLegalToShrinkwrapLifetimeMarkers(Addr))
- return std::make_pair<Instruction *>(nullptr, nullptr);
-
- // Check to see if we have a place to do hoisting, if not, bail.
- if (HoistLifeEnd && !ExitBlock)
- return std::make_pair<Instruction *>(nullptr, nullptr);
-
- return std::make_pair(LifeStart, LifeEnd);
- };
-
- bool SinkLifeStart = false, HoistLifeEnd = false;
- auto Markers = GetLifeTimeMarkers(AI, SinkLifeStart, HoistLifeEnd);
-
- if (Markers.first) {
- if (SinkLifeStart)
- SinkCands.insert(Markers.first);
+ LifetimeMarkerInfo MarkerInfo = getLifetimeMarkers(AI, ExitBlock);
+ bool Moved = moveOrIgnoreLifetimeMarkers(MarkerInfo);
+ if (Moved) {
+ LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI << "\n");
SinkCands.insert(AI);
- if (HoistLifeEnd)
- HoistCands.insert(Markers.second);
continue;
}
- // Follow the bitcast.
- Instruction *MarkerAddr = nullptr;
+ // Follow any bitcasts.
+ SmallVector<Instruction *, 2> Bitcasts;
+ SmallVector<LifetimeMarkerInfo, 2> BitcastLifetimeInfo;
for (User *U : AI->users()) {
if (U->stripInBoundsConstantOffsets() == AI) {
- SinkLifeStart = false;
- HoistLifeEnd = false;
Instruction *Bitcast = cast<Instruction>(U);
- Markers = GetLifeTimeMarkers(Bitcast, SinkLifeStart, HoistLifeEnd);
- if (Markers.first) {
- MarkerAddr = Bitcast;
+ LifetimeMarkerInfo LMI = getLifetimeMarkers(Bitcast, ExitBlock);
+ if (LMI.LifeStart) {
+ Bitcasts.push_back(Bitcast);
+ BitcastLifetimeInfo.push_back(LMI);
continue;
}
}
// Found unknown use of AI.
if (!definedInRegion(Blocks, U)) {
- MarkerAddr = nullptr;
+ Bitcasts.clear();
break;
}
}
- if (MarkerAddr) {
- if (SinkLifeStart)
- SinkCands.insert(Markers.first);
- if (!definedInRegion(Blocks, MarkerAddr))
- SinkCands.insert(MarkerAddr);
- SinkCands.insert(AI);
- if (HoistLifeEnd)
- HoistCands.insert(Markers.second);
+ // Either no bitcasts reference the alloca or there are unknown uses.
+ if (Bitcasts.empty())
+ continue;
+
+ LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI << "\n");
+ SinkCands.insert(AI);
+ for (unsigned I = 0, E = Bitcasts.size(); I != E; ++I) {
+ Instruction *BitcastAddr = Bitcasts[I];
+ const LifetimeMarkerInfo &LMI = BitcastLifetimeInfo[I];
+ assert(LMI.LifeStart &&
+ "Unsafe to sink bitcast without lifetime markers");
+ moveOrIgnoreLifetimeMarkers(LMI);
+ if (!definedInRegion(Blocks, BitcastAddr)) {
+ LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr
+ << "\n");
+ SinkCands.insert(BitcastAddr);
+ }
}
}
}
@@ -780,6 +809,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
case Attribute::NoBuiltin:
case Attribute::NoCapture:
case Attribute::NoReturn:
+ case Attribute::NoSync:
case Attribute::None:
case Attribute::NonNull:
case Attribute::ReadNone:
@@ -792,8 +822,10 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
case Attribute::StructRet:
case Attribute::SwiftError:
case Attribute::SwiftSelf:
+ case Attribute::WillReturn:
case Attribute::WriteOnly:
case Attribute::ZExt:
+ case Attribute::ImmArg:
case Attribute::EndAttrKinds:
continue;
// Those attributes should be safe to propagate to the extracted function.
@@ -803,6 +835,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
case Attribute::InlineHint:
case Attribute::MinSize:
case Attribute::NoDuplicate:
+ case Attribute::NoFree:
case Attribute::NoImplicitFloat:
case Attribute::NoInline:
case Attribute::NonLazyBind:
@@ -817,6 +850,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
case Attribute::SanitizeMemory:
case Attribute::SanitizeThread:
case Attribute::SanitizeHWAddress:
+ case Attribute::SanitizeMemTag:
case Attribute::SpeculativeLoadHardening:
case Attribute::StackProtect:
case Attribute::StackProtectReq:
@@ -845,7 +879,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
Instruction *TI = newFunction->begin()->getTerminator();
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
- RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI);
+ RewriteVal = new LoadInst(StructTy->getElementType(i), GEP,
+ "loadgep_" + inputs[i]->getName(), TI);
} else
RewriteVal = &*AI++;
@@ -880,6 +915,88 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
return newFunction;
}
+/// Erase lifetime.start markers which reference inputs to the extraction
+/// region, and insert the referenced memory into \p LifetimesStart.
+///
+/// The extraction region is defined by a set of blocks (\p Blocks), and a set
+/// of allocas which will be moved from the caller function into the extracted
+/// function (\p SunkAllocas).
+static void eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks,
+ const SetVector<Value *> &SunkAllocas,
+ SetVector<Value *> &LifetimesStart) {
+ for (BasicBlock *BB : Blocks) {
+ for (auto It = BB->begin(), End = BB->end(); It != End;) {
+ auto *II = dyn_cast<IntrinsicInst>(&*It);
+ ++It;
+ if (!II || !II->isLifetimeStartOrEnd())
+ continue;
+
+ // Get the memory operand of the lifetime marker. If the underlying
+ // object is a sunk alloca, or is otherwise defined in the extraction
+ // region, the lifetime marker must not be erased.
+ Value *Mem = II->getOperand(1)->stripInBoundsOffsets();
+ if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem))
+ continue;
+
+ if (II->getIntrinsicID() == Intrinsic::lifetime_start)
+ LifetimesStart.insert(Mem);
+ II->eraseFromParent();
+ }
+ }
+}
+
+/// Insert lifetime start/end markers surrounding the call to the new function
+/// for objects defined in the caller.
+static void insertLifetimeMarkersSurroundingCall(
+ Module *M, ArrayRef<Value *> LifetimesStart, ArrayRef<Value *> LifetimesEnd,
+ CallInst *TheCall) {
+ LLVMContext &Ctx = M->getContext();
+ auto Int8PtrTy = Type::getInt8PtrTy(Ctx);
+ auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1);
+ Instruction *Term = TheCall->getParent()->getTerminator();
+
+ // The memory argument to a lifetime marker must be a i8*. Cache any bitcasts
+ // needed to satisfy this requirement so they may be reused.
+ DenseMap<Value *, Value *> Bitcasts;
+
+ // Emit lifetime markers for the pointers given in \p Objects. Insert the
+ // markers before the call if \p InsertBefore, and after the call otherwise.
+ auto insertMarkers = [&](Function *MarkerFunc, ArrayRef<Value *> Objects,
+ bool InsertBefore) {
+ for (Value *Mem : Objects) {
+ assert((!isa<Instruction>(Mem) || cast<Instruction>(Mem)->getFunction() ==
+ TheCall->getFunction()) &&
+ "Input memory not defined in original function");
+ Value *&MemAsI8Ptr = Bitcasts[Mem];
+ if (!MemAsI8Ptr) {
+ if (Mem->getType() == Int8PtrTy)
+ MemAsI8Ptr = Mem;
+ else
+ MemAsI8Ptr =
+ CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall);
+ }
+
+ auto Marker = CallInst::Create(MarkerFunc, {NegativeOne, MemAsI8Ptr});
+ if (InsertBefore)
+ Marker->insertBefore(TheCall);
+ else
+ Marker->insertBefore(Term);
+ }
+ };
+
+ if (!LifetimesStart.empty()) {
+ auto StartFn = llvm::Intrinsic::getDeclaration(
+ M, llvm::Intrinsic::lifetime_start, Int8PtrTy);
+ insertMarkers(StartFn, LifetimesStart, /*InsertBefore=*/true);
+ }
+
+ if (!LifetimesEnd.empty()) {
+ auto EndFn = llvm::Intrinsic::getDeclaration(
+ M, llvm::Intrinsic::lifetime_end, Int8PtrTy);
+ insertMarkers(EndFn, LifetimesEnd, /*InsertBefore=*/false);
+ }
+}
+
/// emitCallAndSwitchStatement - This method sets up the caller side by adding
/// the call instruction, splitting any PHI nodes in the header block as
/// necessary.
@@ -897,11 +1014,18 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
CallInst *call = nullptr;
// Add inputs as params, or to be filled into the struct
- for (Value *input : inputs)
+ unsigned ArgNo = 0;
+ SmallVector<unsigned, 1> SwiftErrorArgs;
+ for (Value *input : inputs) {
if (AggregateArgs)
StructValues.push_back(input);
- else
+ else {
params.push_back(input);
+ if (input->isSwiftError())
+ SwiftErrorArgs.push_back(ArgNo);
+ }
+ ++ArgNo;
+ }
// Create allocas for the outputs
for (Value *output : outputs) {
@@ -957,13 +1081,18 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
}
codeReplacer->getInstList().push_back(call);
+ // Set swifterror parameter attributes.
+ for (unsigned SwiftErrArgNo : SwiftErrorArgs) {
+ call->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
+ newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
+ }
+
Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
unsigned FirstOut = inputs.size();
if (!AggregateArgs)
std::advance(OutputArgBegin, inputs.size());
// Reload the outputs passed in by reference.
- Function::arg_iterator OAI = OutputArgBegin;
for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
Value *Output = nullptr;
if (AggregateArgs) {
@@ -977,7 +1106,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
} else {
Output = ReloadOutputs[i];
}
- LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload");
+ LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
+ outputs[i]->getName() + ".reload");
Reloads.push_back(load);
codeReplacer->getInstList().push_back(load);
std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end());
@@ -986,40 +1116,6 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
if (!Blocks.count(inst->getParent()))
inst->replaceUsesOfWith(outputs[i], load);
}
-
- // Store to argument right after the definition of output value.
- auto *OutI = dyn_cast<Instruction>(outputs[i]);
- if (!OutI)
- continue;
-
- // Find proper insertion point.
- BasicBlock::iterator InsertPt;
- // In case OutI is an invoke, we insert the store at the beginning in the
- // 'normal destination' BB. Otherwise we insert the store right after OutI.
- if (auto *InvokeI = dyn_cast<InvokeInst>(OutI))
- InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
- else if (auto *Phi = dyn_cast<PHINode>(OutI))
- InsertPt = Phi->getParent()->getFirstInsertionPt();
- else
- InsertPt = std::next(OutI->getIterator());
-
- assert(OAI != newFunction->arg_end() &&
- "Number of output arguments should match "
- "the amount of defined values");
- if (AggregateArgs) {
- Value *Idx[2];
- Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
- Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
- GetElementPtrInst *GEP = GetElementPtrInst::Create(
- StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), &*InsertPt);
- new StoreInst(outputs[i], GEP, &*InsertPt);
- // Since there should be only one struct argument aggregating
- // all the output values, we shouldn't increment OAI, which always
- // points to the struct argument, in this case.
- } else {
- new StoreInst(outputs[i], &*OAI, &*InsertPt);
- ++OAI;
- }
}
// Now we can emit a switch statement using the call as a value.
@@ -1075,6 +1171,50 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
}
}
+ // Store the arguments right after the definition of output value.
+ // This should be proceeded after creating exit stubs to be ensure that invoke
+ // result restore will be placed in the outlined function.
+ Function::arg_iterator OAI = OutputArgBegin;
+ for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+ auto *OutI = dyn_cast<Instruction>(outputs[i]);
+ if (!OutI)
+ continue;
+
+ // Find proper insertion point.
+ BasicBlock::iterator InsertPt;
+ // In case OutI is an invoke, we insert the store at the beginning in the
+ // 'normal destination' BB. Otherwise we insert the store right after OutI.
+ if (auto *InvokeI = dyn_cast<InvokeInst>(OutI))
+ InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
+ else if (auto *Phi = dyn_cast<PHINode>(OutI))
+ InsertPt = Phi->getParent()->getFirstInsertionPt();
+ else
+ InsertPt = std::next(OutI->getIterator());
+
+ Instruction *InsertBefore = &*InsertPt;
+ assert((InsertBefore->getFunction() == newFunction ||
+ Blocks.count(InsertBefore->getParent())) &&
+ "InsertPt should be in new function");
+ assert(OAI != newFunction->arg_end() &&
+ "Number of output arguments should match "
+ "the amount of defined values");
+ if (AggregateArgs) {
+ Value *Idx[2];
+ Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+ GetElementPtrInst *GEP = GetElementPtrInst::Create(
+ StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(),
+ InsertBefore);
+ new StoreInst(outputs[i], GEP, InsertBefore);
+ // Since there should be only one struct argument aggregating
+ // all the output values, we shouldn't increment OAI, which always
+ // points to the struct argument, in this case.
+ } else {
+ new StoreInst(outputs[i], &*OAI, InsertBefore);
+ ++OAI;
+ }
+ }
+
// Now that we've done the deed, simplify the switch instruction.
Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
switch (NumExitBlocks) {
@@ -1119,6 +1259,10 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
break;
}
+ // Insert lifetime markers around the reloads of any output values. The
+ // allocas output values are stored in are only in-use in the codeRepl block.
+ insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call);
+
return call;
}
@@ -1133,6 +1277,13 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) {
// Insert this basic block into the new function
newBlocks.push_back(Block);
+
+ // Remove @llvm.assume calls that were moved to the new function from the
+ // old function's assumption cache.
+ if (AC)
+ for (auto &I : *Block)
+ if (match(&I, m_Intrinsic<Intrinsic::assume>()))
+ AC->unregisterAssumption(cast<CallInst>(&I));
}
}
@@ -1181,71 +1332,6 @@ void CodeExtractor::calculateNewCallTerminatorWeights(
MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
}
-/// Scan the extraction region for lifetime markers which reference inputs.
-/// Erase these markers. Return the inputs which were referenced.
-///
-/// The extraction region is defined by a set of blocks (\p Blocks), and a set
-/// of allocas which will be moved from the caller function into the extracted
-/// function (\p SunkAllocas).
-static SetVector<Value *>
-eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks,
- const SetVector<Value *> &SunkAllocas) {
- SetVector<Value *> InputObjectsWithLifetime;
- for (BasicBlock *BB : Blocks) {
- for (auto It = BB->begin(), End = BB->end(); It != End;) {
- auto *II = dyn_cast<IntrinsicInst>(&*It);
- ++It;
- if (!II || !II->isLifetimeStartOrEnd())
- continue;
-
- // Get the memory operand of the lifetime marker. If the underlying
- // object is a sunk alloca, or is otherwise defined in the extraction
- // region, the lifetime marker must not be erased.
- Value *Mem = II->getOperand(1)->stripInBoundsOffsets();
- if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem))
- continue;
-
- InputObjectsWithLifetime.insert(Mem);
- II->eraseFromParent();
- }
- }
- return InputObjectsWithLifetime;
-}
-
-/// Insert lifetime start/end markers surrounding the call to the new function
-/// for objects defined in the caller.
-static void insertLifetimeMarkersSurroundingCall(
- Module *M, const SetVector<Value *> &InputObjectsWithLifetime,
- CallInst *TheCall) {
- if (InputObjectsWithLifetime.empty())
- return;
-
- LLVMContext &Ctx = M->getContext();
- auto Int8PtrTy = Type::getInt8PtrTy(Ctx);
- auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1);
- auto LifetimeStartFn = llvm::Intrinsic::getDeclaration(
- M, llvm::Intrinsic::lifetime_start, Int8PtrTy);
- auto LifetimeEndFn = llvm::Intrinsic::getDeclaration(
- M, llvm::Intrinsic::lifetime_end, Int8PtrTy);
- for (Value *Mem : InputObjectsWithLifetime) {
- assert((!isa<Instruction>(Mem) ||
- cast<Instruction>(Mem)->getFunction() == TheCall->getFunction()) &&
- "Input memory not defined in original function");
- Value *MemAsI8Ptr = nullptr;
- if (Mem->getType() == Int8PtrTy)
- MemAsI8Ptr = Mem;
- else
- MemAsI8Ptr =
- CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall);
-
- auto StartMarker =
- CallInst::Create(LifetimeStartFn, {NegativeOne, MemAsI8Ptr});
- StartMarker->insertBefore(TheCall);
- auto EndMarker = CallInst::Create(LifetimeEndFn, {NegativeOne, MemAsI8Ptr});
- EndMarker->insertAfter(TheCall);
- }
-}
-
Function *CodeExtractor::extractCodeRegion() {
if (!isEligible())
return nullptr;
@@ -1348,10 +1434,24 @@ Function *CodeExtractor::extractCodeRegion() {
// Find inputs to, outputs from the code region.
findInputsOutputs(inputs, outputs, SinkingCands);
- // Now sink all instructions which only have non-phi uses inside the region
- for (auto *II : SinkingCands)
- cast<Instruction>(II)->moveBefore(*newFuncRoot,
- newFuncRoot->getFirstInsertionPt());
+ // Now sink all instructions which only have non-phi uses inside the region.
+ // Group the allocas at the start of the block, so that any bitcast uses of
+ // the allocas are well-defined.
+ AllocaInst *FirstSunkAlloca = nullptr;
+ for (auto *II : SinkingCands) {
+ if (auto *AI = dyn_cast<AllocaInst>(II)) {
+ AI->moveBefore(*newFuncRoot, newFuncRoot->getFirstInsertionPt());
+ if (!FirstSunkAlloca)
+ FirstSunkAlloca = AI;
+ }
+ }
+ assert((SinkingCands.empty() || FirstSunkAlloca) &&
+ "Did not expect a sink candidate without any allocas");
+ for (auto *II : SinkingCands) {
+ if (!isa<AllocaInst>(II)) {
+ cast<Instruction>(II)->moveAfter(FirstSunkAlloca);
+ }
+ }
if (!HoistingCands.empty()) {
auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
@@ -1361,11 +1461,11 @@ Function *CodeExtractor::extractCodeRegion() {
}
// Collect objects which are inputs to the extraction region and also
- // referenced by lifetime start/end markers within it. The effects of these
+ // referenced by lifetime start markers within it. The effects of these
// markers must be replicated in the calling function to prevent the stack
// coloring pass from merging slots which store input objects.
- ValueSet InputObjectsWithLifetime =
- eraseLifetimeMarkersOnInputs(Blocks, SinkingCands);
+ ValueSet LifetimesStart;
+ eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart);
// Construct new function based on inputs/outputs & add allocas for all defs.
Function *newFunction =
@@ -1388,8 +1488,8 @@ Function *CodeExtractor::extractCodeRegion() {
// Replicate the effects of any lifetime start/end markers which referenced
// input objects in the extraction region by placing markers around the call.
- insertLifetimeMarkersSurroundingCall(oldFunction->getParent(),
- InputObjectsWithLifetime, TheCall);
+ insertLifetimeMarkersSurroundingCall(
+ oldFunction->getParent(), LifetimesStart.getArrayRef(), {}, TheCall);
// Propagate personality info to the new function if there is one.
if (oldFunction->hasPersonalityFn())