diff options
Diffstat (limited to 'llvm/lib/Transforms/Coroutines/CoroElide.cpp')
-rw-r--r-- | llvm/lib/Transforms/Coroutines/CoroElide.cpp | 211 |
1 files changed, 159 insertions, 52 deletions
diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp index 23d22e23861a..9d364b3097c1 100644 --- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -5,12 +5,10 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// This pass replaces dynamic allocation of coroutine frame with alloca and -// replaces calls to llvm.coro.resume and llvm.coro.destroy with direct calls -// to coroutine sub-functions. -//===----------------------------------------------------------------------===// +#include "llvm/Transforms/Coroutines/CoroElide.h" #include "CoroInternal.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/Dominators.h" @@ -30,14 +28,19 @@ struct Lowerer : coro::LowererBase { SmallVector<CoroBeginInst *, 1> CoroBegins; SmallVector<CoroAllocInst *, 1> CoroAllocs; SmallVector<CoroSubFnInst *, 4> ResumeAddr; - SmallVector<CoroSubFnInst *, 4> DestroyAddr; + DenseMap<CoroBeginInst *, SmallVector<CoroSubFnInst *, 4>> DestroyAddr; SmallVector<CoroFreeInst *, 1> CoroFrees; + SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches; Lowerer(Module &M) : LowererBase(M) {} - void elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA); + void elideHeapAllocations(Function *F, uint64_t FrameSize, Align FrameAlign, + AAResults &AA); bool shouldElide(Function *F, DominatorTree &DT) const; + void collectPostSplitCoroIds(Function *F); bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT); + bool hasEscapePath(const CoroBeginInst *, + const SmallPtrSetImpl<BasicBlock *> &) const; }; } // end anonymous namespace @@ -90,10 +93,23 @@ static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) { } } -// Given a resume function @f.resume(%f.frame* %frame), returns %f.frame type. -static Type *getFrameType(Function *Resume) { - auto *ArgType = Resume->arg_begin()->getType(); - return cast<PointerType>(ArgType)->getElementType(); +// Given a resume function @f.resume(%f.frame* %frame), returns the size +// and expected alignment of %f.frame type. +static std::pair<uint64_t, Align> getFrameLayout(Function *Resume) { + // Prefer to pull information from the function attributes. + auto Size = Resume->getParamDereferenceableBytes(0); + auto Align = Resume->getParamAlign(0); + + // If those aren't given, extract them from the type. + if (Size == 0 || !Align) { + auto *FrameTy = Resume->arg_begin()->getType()->getPointerElementType(); + + const DataLayout &DL = Resume->getParent()->getDataLayout(); + if (!Size) Size = DL.getTypeAllocSize(FrameTy); + if (!Align) Align = DL.getABITypeAlign(FrameTy); + } + + return std::make_pair(Size, *Align); } // Finds first non alloca instruction in the entry block of a function. @@ -106,8 +122,9 @@ static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) { // To elide heap allocations we need to suppress code blocks guarded by // llvm.coro.alloc and llvm.coro.free instructions. -void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) { - LLVMContext &C = FrameTy->getContext(); +void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize, + Align FrameAlign, AAResults &AA) { + LLVMContext &C = F->getContext(); auto *InsertPt = getFirstNonAllocaInTheEntryBlock(CoroIds.front()->getFunction()); @@ -128,7 +145,9 @@ void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) { // here. Possibly we will need to do a mini SROA here and break the coroutine // frame into individual AllocaInst recreating the original alignment. const DataLayout &DL = F->getParent()->getDataLayout(); + auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize); auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt); + Frame->setAlignment(FrameAlign); auto *FrameVoidPtr = new BitCastInst(Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt); @@ -142,44 +161,92 @@ void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) { removeTailCallAttribute(Frame, AA); } +bool Lowerer::hasEscapePath(const CoroBeginInst *CB, + const SmallPtrSetImpl<BasicBlock *> &TIs) const { + const auto &It = DestroyAddr.find(CB); + assert(It != DestroyAddr.end()); + + // Limit the number of blocks we visit. + unsigned Limit = 32 * (1 + It->second.size()); + + SmallVector<const BasicBlock *, 32> Worklist; + Worklist.push_back(CB->getParent()); + + SmallPtrSet<const BasicBlock *, 32> Visited; + // Consider basicblock of coro.destroy as visited one, so that we + // skip the path pass through coro.destroy. + for (auto *DA : It->second) + Visited.insert(DA->getParent()); + + do { + const auto *BB = Worklist.pop_back_val(); + if (!Visited.insert(BB).second) + continue; + if (TIs.count(BB)) + return true; + + // Conservatively say that there is potentially a path. + if (!--Limit) + return true; + + auto TI = BB->getTerminator(); + // Although the default dest of coro.suspend switches is suspend pointer + // which means a escape path to normal terminator, it is reasonable to skip + // it since coroutine frame doesn't change outside the coroutine body. + if (isa<SwitchInst>(TI) && + CoroSuspendSwitches.count(cast<SwitchInst>(TI))) { + Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(1)); + Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(2)); + } else + Worklist.append(succ_begin(BB), succ_end(BB)); + + } while (!Worklist.empty()); + + // We have exhausted all possible paths and are certain that coro.begin can + // not reach to any of terminators. + return false; +} + bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const { // If no CoroAllocs, we cannot suppress allocation, so elision is not // possible. if (CoroAllocs.empty()) return false; - // Check that for every coro.begin there is a coro.destroy directly - // referencing the SSA value of that coro.begin along a non-exceptional path. + // Check that for every coro.begin there is at least one coro.destroy directly + // referencing the SSA value of that coro.begin along each + // non-exceptional path. // If the value escaped, then coro.destroy would have been referencing a // memory location storing that value and not the virtual register. + SmallPtrSet<BasicBlock *, 8> Terminators; // First gather all of the non-exceptional terminators for the function. - SmallPtrSet<Instruction *, 8> Terminators; - for (BasicBlock &B : *F) { - auto *TI = B.getTerminator(); - if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() && - !isa<UnreachableInst>(TI)) - Terminators.insert(TI); - } + // Consider the final coro.suspend as the real terminator when the current + // function is a coroutine. + for (BasicBlock &B : *F) { + auto *TI = B.getTerminator(); + if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() && + !isa<UnreachableInst>(TI)) + Terminators.insert(&B); + } // Filter out the coro.destroy that lie along exceptional paths. - SmallPtrSet<CoroSubFnInst *, 4> DAs; - for (CoroSubFnInst *DA : DestroyAddr) { - for (Instruction *TI : Terminators) { - if (DT.dominates(DA, TI)) { - DAs.insert(DA); - break; + SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins; + for (auto &It : DestroyAddr) { + for (Instruction *DA : It.second) { + for (BasicBlock *TI : Terminators) { + if (DT.dominates(DA, TI->getTerminator())) { + ReferencedCoroBegins.insert(It.first); + break; + } } } - } - // Find all the coro.begin referenced by coro.destroy along happy paths. - SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins; - for (CoroSubFnInst *DA : DAs) { - if (auto *CB = dyn_cast<CoroBeginInst>(DA->getFrame())) - ReferencedCoroBegins.insert(CB); - else - return false; + // Whether there is any paths from coro.begin to Terminators which not pass + // through any of the coro.destroys. + if (!ReferencedCoroBegins.count(It.first) && + !hasEscapePath(It.first, Terminators)) + ReferencedCoroBegins.insert(It.first); } // If size of the set is the same as total number of coro.begin, that means we @@ -188,6 +255,30 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const { return ReferencedCoroBegins.size() == CoroBegins.size(); } +void Lowerer::collectPostSplitCoroIds(Function *F) { + CoroIds.clear(); + CoroSuspendSwitches.clear(); + for (auto &I : instructions(F)) { + if (auto *CII = dyn_cast<CoroIdInst>(&I)) + if (CII->getInfo().isPostSplit()) + // If it is the coroutine itself, don't touch it. + if (CII->getCoroutine() != CII->getFunction()) + CoroIds.push_back(CII); + + // Consider case like: + // %0 = call i8 @llvm.coro.suspend(...) + // switch i8 %0, label %suspend [i8 0, label %resume + // i8 1, label %cleanup] + // and collect the SwitchInsts which are used by escape analysis later. + if (auto *CSI = dyn_cast<CoroSuspendInst>(&I)) + if (CSI->hasOneUse() && isa<SwitchInst>(CSI->use_begin()->getUser())) { + SwitchInst *SWI = cast<SwitchInst>(CSI->use_begin()->getUser()); + if (SWI->getNumCases() == 2) + CoroSuspendSwitches.insert(SWI); + } + } +} + bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA, DominatorTree &DT) { CoroBegins.clear(); @@ -218,7 +309,7 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA, ResumeAddr.push_back(II); break; case CoroSubFnInst::DestroyIndex: - DestroyAddr.push_back(II); + DestroyAddr[CB].push_back(II); break; default: llvm_unreachable("unexpected coro.subfn.addr constant"); @@ -241,11 +332,13 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA, Resumers, ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex); - replaceWithConstant(DestroyAddrConstant, DestroyAddr); + for (auto &It : DestroyAddr) + replaceWithConstant(DestroyAddrConstant, It.second); if (ShouldElide) { - auto *FrameTy = getFrameType(cast<Function>(ResumeAddrConstant)); - elideHeapAllocations(CoroId->getFunction(), FrameTy, AA); + auto FrameSizeAndAlign = getFrameLayout(cast<Function>(ResumeAddrConstant)); + elideHeapAllocations(CoroId->getFunction(), FrameSizeAndAlign.first, + FrameSizeAndAlign.second, AA); coro::replaceCoroFree(CoroId, /*Elide=*/true); } @@ -272,9 +365,31 @@ static bool replaceDevirtTrigger(Function &F) { return true; } -//===----------------------------------------------------------------------===// -// Top Level Driver -//===----------------------------------------------------------------------===// +static bool declaresCoroElideIntrinsics(Module &M) { + return coro::declaresIntrinsics(M, {"llvm.coro.id"}); +} + +PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) { + auto &M = *F.getParent(); + if (!declaresCoroElideIntrinsics(M)) + return PreservedAnalyses::all(); + + Lowerer L(M); + L.CoroIds.clear(); + L.collectPostSplitCoroIds(&F); + // If we did not find any coro.id, there is nothing to do. + if (L.CoroIds.empty()) + return PreservedAnalyses::all(); + + AAResults &AA = AM.getResult<AAManager>(F); + DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + + bool Changed = false; + for (auto *CII : L.CoroIds) + Changed |= L.processCoroId(CII, AA, DT); + + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} namespace { struct CoroElideLegacy : FunctionPass { @@ -286,7 +401,7 @@ struct CoroElideLegacy : FunctionPass { std::unique_ptr<Lowerer> L; bool doInitialization(Module &M) override { - if (coro::declaresIntrinsics(M, {"llvm.coro.id"})) + if (declaresCoroElideIntrinsics(M)) L = std::make_unique<Lowerer>(M); return false; } @@ -301,15 +416,7 @@ struct CoroElideLegacy : FunctionPass { Changed = replaceDevirtTrigger(F); L->CoroIds.clear(); - - // Collect all PostSplit coro.ids. - for (auto &I : instructions(F)) - if (auto *CII = dyn_cast<CoroIdInst>(&I)) - if (CII->getInfo().isPostSplit()) - // If it is the coroutine itself, don't touch it. - if (CII->getCoroutine() != CII->getFunction()) - L->CoroIds.push_back(CII); - + L->collectPostSplitCoroIds(&F); // If we did not find any coro.id, there is nothing to do. if (L->CoroIds.empty()) return Changed; |