about | summary | refs | log | tree | commit | diff
path: root/llvm/lib/Transforms/Coroutines/CoroElide.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Coroutines/CoroElide.cpp')
-rw-r--r-- llvm/lib/Transforms/Coroutines/CoroElide.cpp 211
1 files changed, 159 insertions, 52 deletions
diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
index 23d22e23861a..9d364b3097c1 100644
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -5,12 +5,10 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-// This pass replaces dynamic allocation of coroutine frame with alloca and
-// replaces calls to llvm.coro.resume and llvm.coro.destroy with direct calls
-// to coroutine sub-functions.
-//===----------------------------------------------------------------------===//
+#include "llvm/Transforms/Coroutines/CoroElide.h"
#include "CoroInternal.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/Dominators.h"
@@ -30,14 +28,19 @@ struct Lowerer : coro::LowererBase {
SmallVector<CoroBeginInst *, 1> CoroBegins;
SmallVector<CoroAllocInst *, 1> CoroAllocs;
SmallVector<CoroSubFnInst *, 4> ResumeAddr;
- SmallVector<CoroSubFnInst *, 4> DestroyAddr;
+ DenseMap<CoroBeginInst *, SmallVector<CoroSubFnInst *, 4>> DestroyAddr;
SmallVector<CoroFreeInst *, 1> CoroFrees;
+ SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches;
Lowerer(Module &M) : LowererBase(M) {}
- void elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA);
+ void elideHeapAllocations(Function *F, uint64_t FrameSize, Align FrameAlign,
+ AAResults &AA);
bool shouldElide(Function *F, DominatorTree &DT) const;
+ void collectPostSplitCoroIds(Function *F);
bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT);
+ bool hasEscapePath(const CoroBeginInst *,
+ const SmallPtrSetImpl<BasicBlock *> &) const;
};
} // end anonymous namespace
@@ -90,10 +93,23 @@ static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) {
}
}
-// Given a resume function @f.resume(%f.frame* %frame), returns %f.frame type.
-static Type *getFrameType(Function *Resume) {
- auto *ArgType = Resume->arg_begin()->getType();
- return cast<PointerType>(ArgType)->getElementType();
+// Given a resume function @f.resume(%f.frame* %frame), returns the size
+// and expected alignment of %f.frame type.
+static std::pair<uint64_t, Align> getFrameLayout(Function *Resume) {
+ // Prefer to pull information from the function attributes.
+ auto Size = Resume->getParamDereferenceableBytes(0);
+ auto Align = Resume->getParamAlign(0);
+
+ // If those aren't given, extract them from the type.
+ if (Size == 0 || !Align) {
+ auto *FrameTy = Resume->arg_begin()->getType()->getPointerElementType();
+
+ const DataLayout &DL = Resume->getParent()->getDataLayout();
+ if (!Size) Size = DL.getTypeAllocSize(FrameTy);
+ if (!Align) Align = DL.getABITypeAlign(FrameTy);
+ }
+
+ return std::make_pair(Size, *Align);
}
// Finds first non alloca instruction in the entry block of a function.
@@ -106,8 +122,9 @@ static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
// To elide heap allocations we need to suppress code blocks guarded by
// llvm.coro.alloc and llvm.coro.free instructions.
-void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) {
- LLVMContext &C = FrameTy->getContext();
+void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize,
+ Align FrameAlign, AAResults &AA) {
+ LLVMContext &C = F->getContext();
auto *InsertPt =
getFirstNonAllocaInTheEntryBlock(CoroIds.front()->getFunction());
@@ -128,7 +145,9 @@ void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) {
// here. Possibly we will need to do a mini SROA here and break the coroutine
// frame into individual AllocaInst recreating the original alignment.
const DataLayout &DL = F->getParent()->getDataLayout();
+ auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
+ Frame->setAlignment(FrameAlign);
auto *FrameVoidPtr =
new BitCastInst(Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt);
@@ -142,44 +161,92 @@ void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) {
removeTailCallAttribute(Frame, AA);
}
+bool Lowerer::hasEscapePath(const CoroBeginInst *CB,
+ const SmallPtrSetImpl<BasicBlock *> &TIs) const {
+ const auto &It = DestroyAddr.find(CB);
+ assert(It != DestroyAddr.end());
+
+ // Limit the number of blocks we visit.
+ unsigned Limit = 32 * (1 + It->second.size());
+
+ SmallVector<const BasicBlock *, 32> Worklist;
+ Worklist.push_back(CB->getParent());
+
+ SmallPtrSet<const BasicBlock *, 32> Visited;
+ // Treat the basic blocks containing a coro.destroy as already visited, so
+ // that we skip any path that passes through a coro.destroy.
+ for (auto *DA : It->second)
+ Visited.insert(DA->getParent());
+
+ do {
+ const auto *BB = Worklist.pop_back_val();
+ if (!Visited.insert(BB).second)
+ continue;
+ if (TIs.count(BB))
+ return true;
+
+ // Conservatively say that there is potentially a path.
+ if (!--Limit)
+ return true;
+
+ auto TI = BB->getTerminator();
+ // Although the default destination of a coro.suspend switch is the suspend
+ // block, which is an escape path to a normal terminator, it is reasonable to
+ // skip it since the coroutine frame doesn't change outside the coroutine body.
+ if (isa<SwitchInst>(TI) &&
+ CoroSuspendSwitches.count(cast<SwitchInst>(TI))) {
+ Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(1));
+ Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(2));
+ } else
+ Worklist.append(succ_begin(BB), succ_end(BB));
+
+ } while (!Worklist.empty());
+
+ // We have exhausted all possible paths and are certain that coro.begin
+ // cannot reach any of the terminators.
+ return false;
+}
+
bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
// If no CoroAllocs, we cannot suppress allocation, so elision is not
// possible.
if (CoroAllocs.empty())
return false;
- // Check that for every coro.begin there is a coro.destroy directly
- // referencing the SSA value of that coro.begin along a non-exceptional path.
+ // Check that for every coro.begin there is at least one coro.destroy directly
+ // referencing the SSA value of that coro.begin along each
+ // non-exceptional path.
// If the value escaped, then coro.destroy would have been referencing a
// memory location storing that value and not the virtual register.
+ SmallPtrSet<BasicBlock *, 8> Terminators;
// First gather all of the non-exceptional terminators for the function.
- SmallPtrSet<Instruction *, 8> Terminators;
- for (BasicBlock &B : *F) {
- auto *TI = B.getTerminator();
- if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() &&
- !isa<UnreachableInst>(TI))
- Terminators.insert(TI);
- }
+ // Consider the final coro.suspend as the real terminator when the current
+ // function is a coroutine.
+ for (BasicBlock &B : *F) {
+ auto *TI = B.getTerminator();
+ if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() &&
+ !isa<UnreachableInst>(TI))
+ Terminators.insert(&B);
+ }
// Filter out the coro.destroy that lie along exceptional paths.
- SmallPtrSet<CoroSubFnInst *, 4> DAs;
- for (CoroSubFnInst *DA : DestroyAddr) {
- for (Instruction *TI : Terminators) {
- if (DT.dominates(DA, TI)) {
- DAs.insert(DA);
- break;
+ SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
+ for (auto &It : DestroyAddr) {
+ for (Instruction *DA : It.second) {
+ for (BasicBlock *TI : Terminators) {
+ if (DT.dominates(DA, TI->getTerminator())) {
+ ReferencedCoroBegins.insert(It.first);
+ break;
+ }
}
}
- }
- // Find all the coro.begin referenced by coro.destroy along happy paths.
- SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
- for (CoroSubFnInst *DA : DAs) {
- if (auto *CB = dyn_cast<CoroBeginInst>(DA->getFrame()))
- ReferencedCoroBegins.insert(CB);
- else
- return false;
+ // Check whether there is any path from coro.begin to Terminators that does
+ // not pass through any of the coro.destroys.
+ if (!ReferencedCoroBegins.count(It.first) &&
+ !hasEscapePath(It.first, Terminators))
+ ReferencedCoroBegins.insert(It.first);
}
// If size of the set is the same as total number of coro.begin, that means we
@@ -188,6 +255,30 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
return ReferencedCoroBegins.size() == CoroBegins.size();
}
+void Lowerer::collectPostSplitCoroIds(Function *F) {
+ CoroIds.clear();
+ CoroSuspendSwitches.clear();
+ for (auto &I : instructions(F)) {
+ if (auto *CII = dyn_cast<CoroIdInst>(&I))
+ if (CII->getInfo().isPostSplit())
+ // If it is the coroutine itself, don't touch it.
+ if (CII->getCoroutine() != CII->getFunction())
+ CoroIds.push_back(CII);
+
+ // Consider a case like:
+ // %0 = call i8 @llvm.coro.suspend(...)
+ // switch i8 %0, label %suspend [i8 0, label %resume
+ // i8 1, label %cleanup]
+ // and collect the SwitchInsts which are used by escape analysis later.
+ if (auto *CSI = dyn_cast<CoroSuspendInst>(&I))
+ if (CSI->hasOneUse() && isa<SwitchInst>(CSI->use_begin()->getUser())) {
+ SwitchInst *SWI = cast<SwitchInst>(CSI->use_begin()->getUser());
+ if (SWI->getNumCases() == 2)
+ CoroSuspendSwitches.insert(SWI);
+ }
+ }
+}
+
bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
DominatorTree &DT) {
CoroBegins.clear();
@@ -218,7 +309,7 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
ResumeAddr.push_back(II);
break;
case CoroSubFnInst::DestroyIndex:
- DestroyAddr.push_back(II);
+ DestroyAddr[CB].push_back(II);
break;
default:
llvm_unreachable("unexpected coro.subfn.addr constant");
@@ -241,11 +332,13 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
Resumers,
ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex);
- replaceWithConstant(DestroyAddrConstant, DestroyAddr);
+ for (auto &It : DestroyAddr)
+ replaceWithConstant(DestroyAddrConstant, It.second);
if (ShouldElide) {
- auto *FrameTy = getFrameType(cast<Function>(ResumeAddrConstant));
- elideHeapAllocations(CoroId->getFunction(), FrameTy, AA);
+ auto FrameSizeAndAlign = getFrameLayout(cast<Function>(ResumeAddrConstant));
+ elideHeapAllocations(CoroId->getFunction(), FrameSizeAndAlign.first,
+ FrameSizeAndAlign.second, AA);
coro::replaceCoroFree(CoroId, /*Elide=*/true);
}
@@ -272,9 +365,31 @@ static bool replaceDevirtTrigger(Function &F) {
return true;
}
-//===----------------------------------------------------------------------===//
-// Top Level Driver
-//===----------------------------------------------------------------------===//
+static bool declaresCoroElideIntrinsics(Module &M) {
+ return coro::declaresIntrinsics(M, {"llvm.coro.id"});
+}
+
+PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) {
+ auto &M = *F.getParent();
+ if (!declaresCoroElideIntrinsics(M))
+ return PreservedAnalyses::all();
+
+ Lowerer L(M);
+ L.CoroIds.clear();
+ L.collectPostSplitCoroIds(&F);
+ // If we did not find any coro.id, there is nothing to do.
+ if (L.CoroIds.empty())
+ return PreservedAnalyses::all();
+
+ AAResults &AA = AM.getResult<AAManager>(F);
+ DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
+
+ bool Changed = false;
+ for (auto *CII : L.CoroIds)
+ Changed |= L.processCoroId(CII, AA, DT);
+
+ return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
namespace {
struct CoroElideLegacy : FunctionPass {
@@ -286,7 +401,7 @@ struct CoroElideLegacy : FunctionPass {
std::unique_ptr<Lowerer> L;
bool doInitialization(Module &M) override {
- if (coro::declaresIntrinsics(M, {"llvm.coro.id"}))
+ if (declaresCoroElideIntrinsics(M))
L = std::make_unique<Lowerer>(M);
return false;
}
@@ -301,15 +416,7 @@ struct CoroElideLegacy : FunctionPass {
Changed = replaceDevirtTrigger(F);
L->CoroIds.clear();
-
- // Collect all PostSplit coro.ids.
- for (auto &I : instructions(F))
- if (auto *CII = dyn_cast<CoroIdInst>(&I))
- if (CII->getInfo().isPostSplit())
- // If it is the coroutine itself, don't touch it.
- if (CII->getCoroutine() != CII->getFunction())
- L->CoroIds.push_back(CII);
-
+ L->collectPostSplitCoroIds(&F);
// If we did not find any coro.id, there is nothing to do.
if (L->CoroIds.empty())
return Changed;