Diffstat (limited to 'llvm/lib/Transforms/Coroutines')
-rw-r--r--  llvm/lib/Transforms/Coroutines/CoroCleanup.cpp    139
-rw-r--r--  llvm/lib/Transforms/Coroutines/CoroEarly.cpp      274
-rw-r--r--  llvm/lib/Transforms/Coroutines/CoroElide.cpp      343
-rw-r--r--  llvm/lib/Transforms/Coroutines/CoroFrame.cpp     1440
-rw-r--r--  llvm/lib/Transforms/Coroutines/CoroInstr.h        515
-rw-r--r--  llvm/lib/Transforms/Coroutines/CoroInternal.h     243
-rw-r--r--  llvm/lib/Transforms/Coroutines/CoroSplit.cpp     1601
-rw-r--r--  llvm/lib/Transforms/Coroutines/Coroutines.cpp     650
8 files changed, 5205 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
new file mode 100644
index 000000000000..c3e05577f044
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp
@@ -0,0 +1,139 @@
+//===- CoroCleanup.cpp - Coroutine Cleanup Pass ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This pass lowers all remaining coroutine intrinsics.
+//===----------------------------------------------------------------------===//
+
+#include "CoroInternal.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Scalar.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "coro-cleanup"
+
+namespace {
+// Created on demand if CoroCleanup pass has work to do.
+struct Lowerer : coro::LowererBase {
+ IRBuilder<> Builder;
+ Lowerer(Module &M) : LowererBase(M), Builder(Context) {}
+ bool lowerRemainingCoroIntrinsics(Function &F);
+};
+}
+
+static void simplifyCFG(Function &F) {
+ llvm::legacy::FunctionPassManager FPM(F.getParent());
+ FPM.add(createCFGSimplificationPass());
+
+ FPM.doInitialization();
+ FPM.run(F);
+ FPM.doFinalization();
+}
+
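+// Lower llvm.coro.subfn.addr to a load of the function pointer at the given
+// index from the coroutine frame; the resume and destroy function pointers
+// occupy the frame's first two fields.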
+static void lowerSubFn(IRBuilder<> &Builder, CoroSubFnInst *SubFn) {
+ Builder.SetInsertPoint(SubFn);
+ Value *FrameRaw = SubFn->getFrame();
+ int Index = SubFn->getIndex();
+
+ auto *FrameTy = StructType::get(
+ SubFn->getContext(), {Builder.getInt8PtrTy(), Builder.getInt8PtrTy()});
+ PointerType *FramePtrTy = FrameTy->getPointerTo();
+
+ Builder.SetInsertPoint(SubFn);
+ auto *FramePtr = Builder.CreateBitCast(FrameRaw, FramePtrTy);
+ auto *Gep = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, Index);
+ auto *Load = Builder.CreateLoad(FrameTy->getElementType(Index), Gep);
+
+ SubFn->replaceAllUsesWith(Load);
+}
+
+bool Lowerer::lowerRemainingCoroIntrinsics(Function &F) {
+ bool Changed = false;
+
+ for (auto IB = inst_begin(F), E = inst_end(F); IB != E;) {
+ Instruction &I = *IB++;
+ if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
+ switch (II->getIntrinsicID()) {
+ default:
+ continue;
+ case Intrinsic::coro_begin:
+ II->replaceAllUsesWith(II->getArgOperand(1));
+ break;
+ case Intrinsic::coro_free:
+ II->replaceAllUsesWith(II->getArgOperand(1));
+ break;
+ case Intrinsic::coro_alloc:
+ II->replaceAllUsesWith(ConstantInt::getTrue(Context));
+ break;
+ case Intrinsic::coro_id:
+ case Intrinsic::coro_id_retcon:
+ case Intrinsic::coro_id_retcon_once:
+ II->replaceAllUsesWith(ConstantTokenNone::get(Context));
+ break;
+ case Intrinsic::coro_subfn_addr:
+ lowerSubFn(Builder, cast<CoroSubFnInst>(II));
+ break;
+ }
+ II->eraseFromParent();
+ Changed = true;
+ }
+ }
+
+ if (Changed) {
+    // After replacements were made, we can clean up the function body a little.
+ simplifyCFG(F);
+ }
+ return Changed;
+}
+
+//===----------------------------------------------------------------------===//
+// Top Level Driver
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct CoroCleanup : FunctionPass {
+ static char ID; // Pass identification, replacement for typeid
+
+ CoroCleanup() : FunctionPass(ID) {
+ initializeCoroCleanupPass(*PassRegistry::getPassRegistry());
+ }
+
+ std::unique_ptr<Lowerer> L;
+
+ // This pass has work to do only if we find intrinsics we are going to lower
+ // in the module.
+ bool doInitialization(Module &M) override {
+ if (coro::declaresIntrinsics(M, {"llvm.coro.alloc", "llvm.coro.begin",
+ "llvm.coro.subfn.addr", "llvm.coro.free",
+ "llvm.coro.id", "llvm.coro.id.retcon",
+ "llvm.coro.id.retcon.once"}))
+ L = std::make_unique<Lowerer>(M);
+ return false;
+ }
+
+ bool runOnFunction(Function &F) override {
+ if (L)
+ return L->lowerRemainingCoroIntrinsics(F);
+ return false;
+ }
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ if (!L)
+ AU.setPreservesAll();
+ }
+ StringRef getPassName() const override { return "Coroutine Cleanup"; }
+};
+}
+
+char CoroCleanup::ID = 0;
+INITIALIZE_PASS(CoroCleanup, "coro-cleanup",
+ "Lower all coroutine related intrinsics", false, false)
+
+Pass *llvm::createCoroCleanupPass() { return new CoroCleanup(); }
diff --git a/llvm/lib/Transforms/Coroutines/CoroEarly.cpp b/llvm/lib/Transforms/Coroutines/CoroEarly.cpp
new file mode 100644
index 000000000000..55993d33ee4e
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/CoroEarly.cpp
@@ -0,0 +1,274 @@
+//===- CoroEarly.cpp - Coroutine Early Function Pass ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This pass lowers coroutine intrinsics that hide the details of the exact
+// calling convention for coroutine resume and destroy functions and details of
+// the structure of the coroutine frame.
+//===----------------------------------------------------------------------===//
+
+#include "CoroInternal.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "coro-early"
+
+namespace {
+// Created on demand if CoroEarly pass has work to do.
+class Lowerer : public coro::LowererBase {
+ IRBuilder<> Builder;
+ PointerType *const AnyResumeFnPtrTy;
+ Constant *NoopCoro = nullptr;
+
+ void lowerResumeOrDestroy(CallSite CS, CoroSubFnInst::ResumeKind);
+ void lowerCoroPromise(CoroPromiseInst *Intrin);
+ void lowerCoroDone(IntrinsicInst *II);
+ void lowerCoroNoop(IntrinsicInst *II);
+
+public:
+ Lowerer(Module &M)
+ : LowererBase(M), Builder(Context),
+ AnyResumeFnPtrTy(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
+ /*isVarArg=*/false)
+ ->getPointerTo()) {}
+ bool lowerEarlyIntrinsics(Function &F);
+};
+}
+
+// Replace a direct call to coro.resume or coro.destroy with an indirect call to
+// an address returned by coro.subfn.addr intrinsic. This is done so that
+// CGPassManager recognizes devirtualization when CoroElide pass replaces a call
+// to coro.subfn.addr with an appropriate function address.
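+//
+// For example (illustrative IR, names assumed):
+//   call void @llvm.coro.resume(i8* %hdl)
+// becomes:
+//   %0 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
+//   %1 = bitcast i8* %0 to void (i8*)*
+//   call fastcc void %1(i8* %hdl)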
+void Lowerer::lowerResumeOrDestroy(CallSite CS,
+ CoroSubFnInst::ResumeKind Index) {
+ Value *ResumeAddr =
+ makeSubFnCall(CS.getArgOperand(0), Index, CS.getInstruction());
+ CS.setCalledFunction(ResumeAddr);
+ CS.setCallingConv(CallingConv::Fast);
+}
+
+// The coroutine promise field is always at a fixed offset from the beginning
+// of the coroutine frame. The i8* coro.promise(i8* ptr, i32 align, i1 from)
+// intrinsic adds an offset to the passed pointer to move from the coroutine
+// frame to the coroutine promise and vice versa. Since we don't know exactly
+// which coroutine frame it is, we build a coroutine frame mock-up starting
+// with two function pointers, followed by a properly aligned coroutine
+// promise field.
+// TODO: Handle the case when the coroutine promise alloca has an align
+// override.
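+//
+// For example (illustrative, assuming 64-bit pointers and a requested
+// alignment of 8): the mock-up frame is { void (i8*)*, void (i8*)*, i8 }, the
+// promise offset is alignTo(16, 8) == 16, and the i1 'from' operand selects
+// whether +16 or -16 is applied to the pointer.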
+void Lowerer::lowerCoroPromise(CoroPromiseInst *Intrin) {
+ Value *Operand = Intrin->getArgOperand(0);
+  unsigned Alignment = Intrin->getAlignment();
+ Type *Int8Ty = Builder.getInt8Ty();
+
+ auto *SampleStruct =
+ StructType::get(Context, {AnyResumeFnPtrTy, AnyResumeFnPtrTy, Int8Ty});
+ const DataLayout &DL = TheModule.getDataLayout();
+ int64_t Offset = alignTo(
+      DL.getStructLayout(SampleStruct)->getElementOffset(2), Alignment);
+ if (Intrin->isFromPromise())
+ Offset = -Offset;
+
+ Builder.SetInsertPoint(Intrin);
+ Value *Replacement =
+ Builder.CreateConstInBoundsGEP1_32(Int8Ty, Operand, Offset);
+
+ Intrin->replaceAllUsesWith(Replacement);
+ Intrin->eraseFromParent();
+}
+
+// When a coroutine reaches the final suspend point, it zeros out ResumeFnAddr
+// in the coroutine frame (it is UB to resume from a final suspend point).
+// The llvm.coro.done intrinsic is used to check whether a coroutine is
+// suspended at the final suspend point or not.
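+//
+// For example (illustrative IR):
+//   %d = call i1 @llvm.coro.done(i8* %hdl)
+// becomes:
+//   %0 = bitcast i8* %hdl to i8**
+//   %1 = load i8*, i8** %0
+//   %d = icmp eq i8* %1, null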
+void Lowerer::lowerCoroDone(IntrinsicInst *II) {
+ Value *Operand = II->getArgOperand(0);
+
+ // ResumeFnAddr is the first pointer sized element of the coroutine frame.
+ static_assert(coro::Shape::SwitchFieldIndex::Resume == 0,
+ "resume function not at offset zero");
+ auto *FrameTy = Int8Ptr;
+ PointerType *FramePtrTy = FrameTy->getPointerTo();
+
+ Builder.SetInsertPoint(II);
+ auto *BCI = Builder.CreateBitCast(Operand, FramePtrTy);
+ auto *Load = Builder.CreateLoad(BCI);
+ auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
+
+ II->replaceAllUsesWith(Cond);
+ II->eraseFromParent();
+}
+
+void Lowerer::lowerCoroNoop(IntrinsicInst *II) {
+ if (!NoopCoro) {
+ LLVMContext &C = Builder.getContext();
+ Module &M = *II->getModule();
+
+ // Create a noop.frame struct type.
+ StructType *FrameTy = StructType::create(C, "NoopCoro.Frame");
+ auto *FramePtrTy = FrameTy->getPointerTo();
+ auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy,
+ /*isVarArg=*/false);
+ auto *FnPtrTy = FnTy->getPointerTo();
+ FrameTy->setBody({FnPtrTy, FnPtrTy});
+
+ // Create a Noop function that does nothing.
+ Function *NoopFn =
+ Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
+ "NoopCoro.ResumeDestroy", &M);
+ NoopFn->setCallingConv(CallingConv::Fast);
+ auto *Entry = BasicBlock::Create(C, "entry", NoopFn);
+ ReturnInst::Create(C, Entry);
+
+ // Create a constant struct for the frame.
+ Constant* Values[] = {NoopFn, NoopFn};
+ Constant* NoopCoroConst = ConstantStruct::get(FrameTy, Values);
+ NoopCoro = new GlobalVariable(M, NoopCoroConst->getType(), /*isConstant=*/true,
+ GlobalVariable::PrivateLinkage, NoopCoroConst,
+ "NoopCoro.Frame.Const");
+ }
+
+ Builder.SetInsertPoint(II);
+ auto *NoopCoroVoidPtr = Builder.CreateBitCast(NoopCoro, Int8Ptr);
+ II->replaceAllUsesWith(NoopCoroVoidPtr);
+ II->eraseFromParent();
+}
+
+// Prior to CoroSplit, calls to coro.begin need to be marked as NoDuplicate,
+// as CoroSplit assumes there is exactly one coro.begin. After CoroSplit, the
+// NoDuplicate attribute is removed from coro.begin; otherwise, it would
+// interfere with inlining.
+static void setCannotDuplicate(CoroIdInst *CoroId) {
+ for (User *U : CoroId->users())
+ if (auto *CB = dyn_cast<CoroBeginInst>(U))
+ CB->setCannotDuplicate();
+}
+
+bool Lowerer::lowerEarlyIntrinsics(Function &F) {
+ bool Changed = false;
+ CoroIdInst *CoroId = nullptr;
+ SmallVector<CoroFreeInst *, 4> CoroFrees;
+ for (auto IB = inst_begin(F), IE = inst_end(F); IB != IE;) {
+ Instruction &I = *IB++;
+ if (auto CS = CallSite(&I)) {
+ switch (CS.getIntrinsicID()) {
+ default:
+ continue;
+ case Intrinsic::coro_free:
+ CoroFrees.push_back(cast<CoroFreeInst>(&I));
+ break;
+ case Intrinsic::coro_suspend:
+ // Make sure that final suspend point is not duplicated as CoroSplit
+ // pass expects that there is at most one final suspend point.
+ if (cast<CoroSuspendInst>(&I)->isFinal())
+ CS.setCannotDuplicate();
+ break;
+ case Intrinsic::coro_end:
+ // Make sure that fallthrough coro.end is not duplicated as CoroSplit
+ // pass expects that there is at most one fallthrough coro.end.
+ if (cast<CoroEndInst>(&I)->isFallthrough())
+ CS.setCannotDuplicate();
+ break;
+ case Intrinsic::coro_noop:
+ lowerCoroNoop(cast<IntrinsicInst>(&I));
+ break;
+ case Intrinsic::coro_id:
+        // Mark a function that comes out of the frontend and has a coro.id
+ // with a coroutine attribute.
+ if (auto *CII = cast<CoroIdInst>(&I)) {
+ if (CII->getInfo().isPreSplit()) {
+ F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT);
+ setCannotDuplicate(CII);
+ CII->setCoroutineSelf();
+ CoroId = cast<CoroIdInst>(&I);
+ }
+ }
+ break;
+ case Intrinsic::coro_id_retcon:
+ case Intrinsic::coro_id_retcon_once:
+ F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
+ break;
+ case Intrinsic::coro_resume:
+ lowerResumeOrDestroy(CS, CoroSubFnInst::ResumeIndex);
+ break;
+ case Intrinsic::coro_destroy:
+ lowerResumeOrDestroy(CS, CoroSubFnInst::DestroyIndex);
+ break;
+ case Intrinsic::coro_promise:
+ lowerCoroPromise(cast<CoroPromiseInst>(&I));
+ break;
+ case Intrinsic::coro_done:
+ lowerCoroDone(cast<IntrinsicInst>(&I));
+ break;
+ }
+ Changed = true;
+ }
+ }
+  // Make sure that all CoroFrees reference the coro.id intrinsic.
+  // The token type is not exposed through the coroutine C/C++ builtins to
+  // plain C, so we allow specifying none and fix it up here.
+ if (CoroId)
+ for (CoroFreeInst *CF : CoroFrees)
+ CF->setArgOperand(0, CoroId);
+ return Changed;
+}
+
+//===----------------------------------------------------------------------===//
+// Top Level Driver
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct CoroEarly : public FunctionPass {
+ static char ID; // Pass identification, replacement for typeid.
+ CoroEarly() : FunctionPass(ID) {
+ initializeCoroEarlyPass(*PassRegistry::getPassRegistry());
+ }
+
+ std::unique_ptr<Lowerer> L;
+
+ // This pass has work to do only if we find intrinsics we are going to lower
+ // in the module.
+ bool doInitialization(Module &M) override {
+ if (coro::declaresIntrinsics(M, {"llvm.coro.id",
+ "llvm.coro.id.retcon",
+ "llvm.coro.id.retcon.once",
+ "llvm.coro.destroy",
+ "llvm.coro.done",
+ "llvm.coro.end",
+ "llvm.coro.noop",
+ "llvm.coro.free",
+ "llvm.coro.promise",
+ "llvm.coro.resume",
+ "llvm.coro.suspend"}))
+ L = std::make_unique<Lowerer>(M);
+ return false;
+ }
+
+ bool runOnFunction(Function &F) override {
+ if (!L)
+ return false;
+
+ return L->lowerEarlyIntrinsics(F);
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesCFG();
+ }
+ StringRef getPassName() const override {
+ return "Lower early coroutine intrinsics";
+ }
+};
+}
+
+char CoroEarly::ID = 0;
+INITIALIZE_PASS(CoroEarly, "coro-early", "Lower early coroutine intrinsics",
+ false, false)
+
+Pass *llvm::createCoroEarlyPass() { return new CoroEarly(); }
diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
new file mode 100644
index 000000000000..aca77119023b
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -0,0 +1,343 @@
+//===- CoroElide.cpp - Coroutine Frame Allocation Elision Pass ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This pass replaces dynamic allocation of the coroutine frame with an alloca
+// and replaces calls to llvm.coro.resume and llvm.coro.destroy with direct
+// calls to coroutine sub-functions.
+//===----------------------------------------------------------------------===//
+
+#include "CoroInternal.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "coro-elide"
+
+namespace {
+// Created on demand if CoroElide pass has work to do.
+struct Lowerer : coro::LowererBase {
+ SmallVector<CoroIdInst *, 4> CoroIds;
+ SmallVector<CoroBeginInst *, 1> CoroBegins;
+ SmallVector<CoroAllocInst *, 1> CoroAllocs;
+ SmallVector<CoroSubFnInst *, 4> ResumeAddr;
+ SmallVector<CoroSubFnInst *, 4> DestroyAddr;
+ SmallVector<CoroFreeInst *, 1> CoroFrees;
+
+ Lowerer(Module &M) : LowererBase(M) {}
+
+ void elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA);
+ bool shouldElide(Function *F, DominatorTree &DT) const;
+ bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT);
+};
+} // end anonymous namespace
+
+// Go through the list of coro.subfn.addr intrinsics and replace them with the
+// provided constant.
+static void replaceWithConstant(Constant *Value,
+ SmallVectorImpl<CoroSubFnInst *> &Users) {
+ if (Users.empty())
+ return;
+
+ // See if we need to bitcast the constant to match the type of the intrinsic
+ // being replaced. Note: All coro.subfn.addr intrinsics return the same type,
+ // so we only need to examine the type of the first one in the list.
+ Type *IntrTy = Users.front()->getType();
+ Type *ValueTy = Value->getType();
+ if (ValueTy != IntrTy) {
+ // May need to tweak the function type to match the type expected at the
+ // use site.
+ assert(ValueTy->isPointerTy() && IntrTy->isPointerTy());
+ Value = ConstantExpr::getBitCast(Value, IntrTy);
+ }
+
+ // Now the value type matches the type of the intrinsic. Replace them all!
+ for (CoroSubFnInst *I : Users)
+ replaceAndRecursivelySimplify(I, Value);
+}
+
+// See if any operand of the call instruction references the coroutine frame.
+static bool operandReferences(CallInst *CI, AllocaInst *Frame, AAResults &AA) {
+ for (Value *Op : CI->operand_values())
+ if (AA.alias(Op, Frame) != NoAlias)
+ return true;
+ return false;
+}
+
+// Look for any tail calls referencing the coroutine frame and remove the tail
+// attribute from them, since the coroutine frame now resides on the stack and
+// a tail call implies that the function does not reference anything on the
+// stack.
+static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) {
+ Function &F = *Frame->getFunction();
+ for (Instruction &I : instructions(F))
+ if (auto *Call = dyn_cast<CallInst>(&I))
+ if (Call->isTailCall() && operandReferences(Call, Frame, AA)) {
+        // FIXME: If we ever hit this check, evaluate whether it is more
+ // appropriate to retain musttail and allow the code to compile.
+ if (Call->isMustTailCall())
+ report_fatal_error("Call referring to the coroutine frame cannot be "
+ "marked as musttail");
+ Call->setTailCall(false);
+ }
+}
+
+// Given a resume function @f.resume(%f.frame* %frame), returns %f.frame type.
+static Type *getFrameType(Function *Resume) {
+ auto *ArgType = Resume->arg_begin()->getType();
+ return cast<PointerType>(ArgType)->getElementType();
+}
+
+// Finds the first non-alloca instruction in the entry block of a function.
+static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
+ for (Instruction &I : F->getEntryBlock())
+ if (!isa<AllocaInst>(&I))
+ return &I;
+ llvm_unreachable("no terminator in the entry block");
+}
+
+// To elide heap allocations we need to suppress code blocks guarded by
+// llvm.coro.alloc and llvm.coro.free instructions.
+void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) {
+ LLVMContext &C = FrameTy->getContext();
+ auto *InsertPt =
+ getFirstNonAllocaInTheEntryBlock(CoroIds.front()->getFunction());
+
+  // Replacing llvm.coro.alloc with false will suppress dynamic allocation,
+  // as the frontend is expected to generate code that looks like:
+ // id = coro.id(...)
+ // mem = coro.alloc(id) ? malloc(coro.size()) : 0;
+ // coro.begin(id, mem)
+ auto *False = ConstantInt::getFalse(C);
+ for (auto *CA : CoroAllocs) {
+ CA->replaceAllUsesWith(False);
+ CA->eraseFromParent();
+ }
+
+ // FIXME: Design how to transmit alignment information for every alloca that
+ // is spilled into the coroutine frame and recreate the alignment information
+ // here. Possibly we will need to do a mini SROA here and break the coroutine
+ // frame into individual AllocaInst recreating the original alignment.
+ const DataLayout &DL = F->getParent()->getDataLayout();
+ auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
+ auto *FrameVoidPtr =
+ new BitCastInst(Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt);
+
+ for (auto *CB : CoroBegins) {
+ CB->replaceAllUsesWith(FrameVoidPtr);
+ CB->eraseFromParent();
+ }
+
+  // Since the coroutine frame now lives on the stack, we need to make sure
+  // that any tail call referencing it is made a non-tail call.
+ removeTailCallAttribute(Frame, AA);
+}
+
+bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
+ // If no CoroAllocs, we cannot suppress allocation, so elision is not
+ // possible.
+ if (CoroAllocs.empty())
+ return false;
+
+ // Check that for every coro.begin there is a coro.destroy directly
+ // referencing the SSA value of that coro.begin along a non-exceptional path.
+ // If the value escaped, then coro.destroy would have been referencing a
+ // memory location storing that value and not the virtual register.
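+  // (For example, if the coroutine handle is stored to memory and only a
+  // reloaded copy is destroyed, the destroy is not tied to this coro.begin
+  // and we conservatively do not elide.)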
+
+ // First gather all of the non-exceptional terminators for the function.
+ SmallPtrSet<Instruction *, 8> Terminators;
+ for (BasicBlock &B : *F) {
+ auto *TI = B.getTerminator();
+ if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() &&
+ !isa<UnreachableInst>(TI))
+ Terminators.insert(TI);
+ }
+
+ // Filter out the coro.destroy that lie along exceptional paths.
+ SmallPtrSet<CoroSubFnInst *, 4> DAs;
+ for (CoroSubFnInst *DA : DestroyAddr) {
+ for (Instruction *TI : Terminators) {
+ if (DT.dominates(DA, TI)) {
+ DAs.insert(DA);
+ break;
+ }
+ }
+ }
+
+ // Find all the coro.begin referenced by coro.destroy along happy paths.
+ SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
+ for (CoroSubFnInst *DA : DAs) {
+ if (auto *CB = dyn_cast<CoroBeginInst>(DA->getFrame()))
+ ReferencedCoroBegins.insert(CB);
+ else
+ return false;
+ }
+
+  // If the size of the set is the same as the total number of coro.begins,
+  // that means we found a coro.free or coro.destroy referencing each
+  // coro.begin, so we can perform heap elision.
+ return ReferencedCoroBegins.size() == CoroBegins.size();
+}
+
+bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
+ DominatorTree &DT) {
+ CoroBegins.clear();
+ CoroAllocs.clear();
+ CoroFrees.clear();
+ ResumeAddr.clear();
+ DestroyAddr.clear();
+
+ // Collect all coro.begin and coro.allocs associated with this coro.id.
+ for (User *U : CoroId->users()) {
+ if (auto *CB = dyn_cast<CoroBeginInst>(U))
+ CoroBegins.push_back(CB);
+ else if (auto *CA = dyn_cast<CoroAllocInst>(U))
+ CoroAllocs.push_back(CA);
+ else if (auto *CF = dyn_cast<CoroFreeInst>(U))
+ CoroFrees.push_back(CF);
+ }
+
+ // Collect all coro.subfn.addrs associated with coro.begin.
+ // Note, we only devirtualize the calls if their coro.subfn.addr refers to
+ // coro.begin directly. If we run into cases where this check is too
+ // conservative, we can consider relaxing the check.
+ for (CoroBeginInst *CB : CoroBegins) {
+ for (User *U : CB->users())
+ if (auto *II = dyn_cast<CoroSubFnInst>(U))
+ switch (II->getIndex()) {
+ case CoroSubFnInst::ResumeIndex:
+ ResumeAddr.push_back(II);
+ break;
+ case CoroSubFnInst::DestroyIndex:
+ DestroyAddr.push_back(II);
+ break;
+ default:
+ llvm_unreachable("unexpected coro.subfn.addr constant");
+ }
+ }
+
+ // PostSplit coro.id refers to an array of subfunctions in its Info
+ // argument.
+ ConstantArray *Resumers = CoroId->getInfo().Resumers;
+  assert(Resumers && "PostSplit coro.id Info argument must refer to an array "
+                     "of coroutine subfunctions");
+ auto *ResumeAddrConstant =
+ ConstantExpr::getExtractValue(Resumers, CoroSubFnInst::ResumeIndex);
+
+ replaceWithConstant(ResumeAddrConstant, ResumeAddr);
+
+ bool ShouldElide = shouldElide(CoroId->getFunction(), DT);
+
+ auto *DestroyAddrConstant = ConstantExpr::getExtractValue(
+ Resumers,
+ ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex);
+
+ replaceWithConstant(DestroyAddrConstant, DestroyAddr);
+
+ if (ShouldElide) {
+ auto *FrameTy = getFrameType(cast<Function>(ResumeAddrConstant));
+ elideHeapAllocations(CoroId->getFunction(), FrameTy, AA);
+ coro::replaceCoroFree(CoroId, /*Elide=*/true);
+ }
+
+ return true;
+}
+
+// See if there are any coro.subfn.addr instructions referring to the
+// coro.devirt trigger; if so, replace them with a direct call to the devirt
+// trigger function.
+static bool replaceDevirtTrigger(Function &F) {
+ SmallVector<CoroSubFnInst *, 1> DevirtAddr;
+ for (auto &I : instructions(F))
+ if (auto *SubFn = dyn_cast<CoroSubFnInst>(&I))
+ if (SubFn->getIndex() == CoroSubFnInst::RestartTrigger)
+ DevirtAddr.push_back(SubFn);
+
+ if (DevirtAddr.empty())
+ return false;
+
+ Module &M = *F.getParent();
+ Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
+ assert(DevirtFn && "coro.devirt.fn not found");
+ replaceWithConstant(DevirtFn, DevirtAddr);
+
+ return true;
+}
+
+//===----------------------------------------------------------------------===//
+// Top Level Driver
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct CoroElide : FunctionPass {
+ static char ID;
+ CoroElide() : FunctionPass(ID) {
+ initializeCoroElidePass(*PassRegistry::getPassRegistry());
+ }
+
+ std::unique_ptr<Lowerer> L;
+
+ bool doInitialization(Module &M) override {
+ if (coro::declaresIntrinsics(M, {"llvm.coro.id"}))
+ L = std::make_unique<Lowerer>(M);
+ return false;
+ }
+
+ bool runOnFunction(Function &F) override {
+ if (!L)
+ return false;
+
+ bool Changed = false;
+
+ if (F.hasFnAttribute(CORO_PRESPLIT_ATTR))
+ Changed = replaceDevirtTrigger(F);
+
+ L->CoroIds.clear();
+
+ // Collect all PostSplit coro.ids.
+ for (auto &I : instructions(F))
+ if (auto *CII = dyn_cast<CoroIdInst>(&I))
+ if (CII->getInfo().isPostSplit())
+ // If it is the coroutine itself, don't touch it.
+ if (CII->getCoroutine() != CII->getFunction())
+ L->CoroIds.push_back(CII);
+
+ // If we did not find any coro.id, there is nothing to do.
+ if (L->CoroIds.empty())
+ return Changed;
+
+ AAResults &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
+ DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+
+ for (auto *CII : L->CoroIds)
+ Changed |= L->processCoroId(CII, AA, DT);
+
+ return Changed;
+ }
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<AAResultsWrapperPass>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ }
+ StringRef getPassName() const override { return "Coroutine Elision"; }
+};
+}
+
+char CoroElide::ID = 0;
+INITIALIZE_PASS_BEGIN(
+ CoroElide, "coro-elide",
+ "Coroutine frame allocation elision and indirect calls replacement", false,
+ false)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_END(
+ CoroElide, "coro-elide",
+ "Coroutine frame allocation elision and indirect calls replacement", false,
+ false)
+
+Pass *llvm::createCoroElidePass() { return new CoroElide(); }
diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
new file mode 100644
index 000000000000..2c42cf8a6d25
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -0,0 +1,1440 @@
+//===- CoroFrame.cpp - Builds and manipulates coroutine frame -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This file contains classes used to discover if, for a particular value,
+// there is a use separated from its definition by a suspend point.
+//
+// Using the information discovered we form a Coroutine Frame structure to
+// contain those values. All uses of those values are replaced with appropriate
+// GEP + load from the coroutine frame. At the point of the definition we spill
+// the value into the coroutine frame.
+//
+// TODO: pack values tightly using liveness info.
+//===----------------------------------------------------------------------===//
+
+#include "CoroInternal.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/Analysis/PtrUseVisitor.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Config/llvm-config.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/circular_raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/PromoteMemToReg.h"
+
+using namespace llvm;
+
+// The "coro-suspend-crossing" flag is very noisy. There is another debug type,
+// "coro-frame", which results in leaner debug spew.
+#define DEBUG_TYPE "coro-suspend-crossing"
+
+enum { SmallVectorThreshold = 32 };
+
+// Provides a two-way mapping between the blocks and numbers.
+namespace {
+class BlockToIndexMapping {
+ SmallVector<BasicBlock *, SmallVectorThreshold> V;
+
+public:
+ size_t size() const { return V.size(); }
+
+ BlockToIndexMapping(Function &F) {
+ for (BasicBlock &BB : F)
+ V.push_back(&BB);
+ llvm::sort(V);
+ }
+
+ size_t blockToIndex(BasicBlock *BB) const {
+ auto *I = llvm::lower_bound(V, BB);
+    assert(I != V.end() && *I == BB && "BasicBlockNumbering: Unknown block");
+ return I - V.begin();
+ }
+
+ BasicBlock *indexToBlock(unsigned Index) const { return V[Index]; }
+};
+} // end anonymous namespace
+
+// The SuspendCrossingInfo maintains data that allows answering the question:
+// given two BasicBlocks A and B, is there a path from A to B that passes
+// through a suspend point?
+//
+// For every basic block 'i' it maintains a BlockData that consists of:
+//   Consumes:  a bit vector which contains a set of indices of blocks that can
+//              reach block 'i'
+//   Kills: a bit vector which contains a set of indices of blocks that can
+//          reach block 'i', but at least one of the paths crosses a suspend
+//          point
+//   Suspend: a boolean indicating whether block 'i' contains a suspend point.
+//   End: a boolean indicating whether block 'i' contains a coro.end intrinsic.
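+//
+// As an illustrative example (not from the source), consider a CFG
+//   entry -> susp -> use
+// where 'susp' contains a suspend point. After propagation, Consumes[use]
+// contains {entry, susp, use} and Kills[use] contains {entry, susp}, so
+// hasPathCrossingSuspendPoint(entry, use) returns true.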
+//
+namespace {
+struct SuspendCrossingInfo {
+ BlockToIndexMapping Mapping;
+
+ struct BlockData {
+ BitVector Consumes;
+ BitVector Kills;
+ bool Suspend = false;
+ bool End = false;
+ };
+ SmallVector<BlockData, SmallVectorThreshold> Block;
+
+ iterator_range<succ_iterator> successors(BlockData const &BD) const {
+ BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
+ return llvm::successors(BB);
+ }
+
+ BlockData &getBlockData(BasicBlock *BB) {
+ return Block[Mapping.blockToIndex(BB)];
+ }
+
+ void dump() const;
+ void dump(StringRef Label, BitVector const &BV) const;
+
+ SuspendCrossingInfo(Function &F, coro::Shape &Shape);
+
+ bool hasPathCrossingSuspendPoint(BasicBlock *DefBB, BasicBlock *UseBB) const {
+ size_t const DefIndex = Mapping.blockToIndex(DefBB);
+ size_t const UseIndex = Mapping.blockToIndex(UseBB);
+
+ assert(Block[UseIndex].Consumes[DefIndex] && "use must consume def");
+ bool const Result = Block[UseIndex].Kills[DefIndex];
+ LLVM_DEBUG(dbgs() << UseBB->getName() << " => " << DefBB->getName()
+ << " answer is " << Result << "\n");
+ return Result;
+ }
+
+ bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const {
+ auto *I = cast<Instruction>(U);
+
+ // We rewrote PHINodes, so that only the ones with exactly one incoming
+ // value need to be analyzed.
+ if (auto *PN = dyn_cast<PHINode>(I))
+ if (PN->getNumIncomingValues() > 1)
+ return false;
+
+ BasicBlock *UseBB = I->getParent();
+
+ // As a special case, treat uses by an llvm.coro.suspend.retcon
+ // as if they were uses in the suspend's single predecessor: the
+ // uses conceptually occur before the suspend.
+ if (isa<CoroSuspendRetconInst>(I)) {
+ UseBB = UseBB->getSinglePredecessor();
+ assert(UseBB && "should have split coro.suspend into its own block");
+ }
+
+ return hasPathCrossingSuspendPoint(DefBB, UseBB);
+ }
+
+ bool isDefinitionAcrossSuspend(Argument &A, User *U) const {
+ return isDefinitionAcrossSuspend(&A.getParent()->getEntryBlock(), U);
+ }
+
+ bool isDefinitionAcrossSuspend(Instruction &I, User *U) const {
+ auto *DefBB = I.getParent();
+
+ // As a special case, treat values produced by an llvm.coro.suspend.*
+ // as if they were defined in the single successor: the uses
+ // conceptually occur after the suspend.
+ if (isa<AnyCoroSuspendInst>(I)) {
+ DefBB = DefBB->getSingleSuccessor();
+ assert(DefBB && "should have split coro.suspend into its own block");
+ }
+
+ return isDefinitionAcrossSuspend(DefBB, U);
+ }
+};
+} // end anonymous namespace
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(StringRef Label,
+ BitVector const &BV) const {
+ dbgs() << Label << ":";
+ for (size_t I = 0, N = BV.size(); I < N; ++I)
+ if (BV[I])
+ dbgs() << " " << Mapping.indexToBlock(I)->getName();
+ dbgs() << "\n";
+}
+
+LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
+ for (size_t I = 0, N = Block.size(); I < N; ++I) {
+ BasicBlock *const B = Mapping.indexToBlock(I);
+ dbgs() << B->getName() << ":\n";
+ dump(" Consumes", Block[I].Consumes);
+ dump(" Kills", Block[I].Kills);
+ }
+ dbgs() << "\n";
+}
+#endif
+
+SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
+ : Mapping(F) {
+ const size_t N = Mapping.size();
+ Block.resize(N);
+
+ // Initialize every block so that it consumes itself
+ for (size_t I = 0; I < N; ++I) {
+ auto &B = Block[I];
+ B.Consumes.resize(N);
+ B.Kills.resize(N);
+ B.Consumes.set(I);
+ }
+
+ // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
+ // the code beyond coro.end is reachable during initial invocation of the
+ // coroutine.
+ for (auto *CE : Shape.CoroEnds)
+ getBlockData(CE->getParent()).End = true;
+
+ // Mark all suspend blocks and indicate that they kill everything they
+  // consume. Note that crossing coro.save also requires a spill, as any code
+ // between coro.save and coro.suspend may resume the coroutine and all of the
+ // state needs to be saved by that time.
+ auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) {
+ BasicBlock *SuspendBlock = BarrierInst->getParent();
+ auto &B = getBlockData(SuspendBlock);
+ B.Suspend = true;
+ B.Kills |= B.Consumes;
+ };
+ for (auto *CSI : Shape.CoroSuspends) {
+ markSuspendBlock(CSI);
+ if (auto *Save = CSI->getCoroSave())
+ markSuspendBlock(Save);
+ }
+
+ // Iterate propagating consumes and kills until they stop changing.
+ int Iteration = 0;
+ (void)Iteration;
+
+ bool Changed;
+ do {
+ LLVM_DEBUG(dbgs() << "iteration " << ++Iteration);
+ LLVM_DEBUG(dbgs() << "==============\n");
+
+ Changed = false;
+ for (size_t I = 0; I < N; ++I) {
+ auto &B = Block[I];
+ for (BasicBlock *SI : successors(B)) {
+
+ auto SuccNo = Mapping.blockToIndex(SI);
+
+        // Save the Consumes and Kills bitsets so that it is easy to see if
+        // anything changed after propagation.
+ auto &S = Block[SuccNo];
+ auto SavedConsumes = S.Consumes;
+ auto SavedKills = S.Kills;
+
+ // Propagate Kills and Consumes from block B into its successor S.
+ S.Consumes |= B.Consumes;
+ S.Kills |= B.Kills;
+
+        // If block B is a suspend block, it should propagate kills into its
+        // successor for every block B consumes.
+ if (B.Suspend) {
+ S.Kills |= B.Consumes;
+ }
+ if (S.Suspend) {
+ // If block S is a suspend block, it should kill all of the blocks it
+ // consumes.
+ S.Kills |= S.Consumes;
+ } else if (S.End) {
+ // If block S is an end block, it should not propagate kills as the
+ // blocks following coro.end() are reached during initial invocation
+ // of the coroutine while all the data are still available on the
+ // stack or in the registers.
+ S.Kills.reset();
+ } else {
+          // This is reached when block S is neither a suspend block nor a
+          // coro.end block; we need to make sure that S is not in its own
+          // kill set.
+ S.Kills.reset(SuccNo);
+ }
+
+ // See if anything changed.
+ Changed |= (S.Kills != SavedKills) || (S.Consumes != SavedConsumes);
+
+ if (S.Kills != SavedKills) {
+ LLVM_DEBUG(dbgs() << "\nblock " << I << " follower " << SI->getName()
+ << "\n");
+ LLVM_DEBUG(dump("S.Kills", S.Kills));
+ LLVM_DEBUG(dump("SavedKills", SavedKills));
+ }
+ if (S.Consumes != SavedConsumes) {
+          LLVM_DEBUG(dbgs() << "\nblock " << I << " follower "
+                            << SI->getName() << "\n");
+ LLVM_DEBUG(dump("S.Consume", S.Consumes));
+ LLVM_DEBUG(dump("SavedCons", SavedConsumes));
+ }
+ }
+ }
+ } while (Changed);
+ LLVM_DEBUG(dump());
+}
+
+#undef DEBUG_TYPE // "coro-suspend-crossing"
+#define DEBUG_TYPE "coro-frame"
+
+// We build up the list of spills for every case where a use is separated
+// from the definition by a suspend point.
+
+static const unsigned InvalidFieldIndex = ~0U;
+
+namespace {
+class Spill {
+ Value *Def = nullptr;
+ Instruction *User = nullptr;
+ unsigned FieldNo = InvalidFieldIndex;
+
+public:
+ Spill(Value *Def, llvm::User *U) : Def(Def), User(cast<Instruction>(U)) {}
+
+ Value *def() const { return Def; }
+ Instruction *user() const { return User; }
+ BasicBlock *userBlock() const { return User->getParent(); }
+
+  // Note that the field index is stored in the first SpillEntry for a
+  // particular definition. Subsequent mentions of a definition do not have
+  // fieldNo assigned. This works out fine as the users of Spills capture the
+  // info about the definition the first time they encounter it. Consider
+  // refactoring SpillInfo into two arrays to normalize the spill
+  // representation.
+ unsigned fieldIndex() const {
+ assert(FieldNo != InvalidFieldIndex && "Accessing unassigned field");
+ return FieldNo;
+ }
+ void setFieldIndex(unsigned FieldNumber) {
+ assert(FieldNo == InvalidFieldIndex && "Reassigning field number");
+ FieldNo = FieldNumber;
+ }
+};
+} // namespace
+
+// Note that there may be more than one record with the same value of Def in
+// the SpillInfo vector.
+using SpillInfo = SmallVector<Spill, 8>;
+
+#ifndef NDEBUG
+static void dump(StringRef Title, SpillInfo const &Spills) {
+ dbgs() << "------------- " << Title << "--------------\n";
+ Value *CurrentValue = nullptr;
+ for (auto const &E : Spills) {
+ if (CurrentValue != E.def()) {
+ CurrentValue = E.def();
+ CurrentValue->dump();
+ }
+ dbgs() << " user: ";
+ E.user()->dump();
+ }
+}
+#endif
+
+namespace {
+// We cannot rely solely on the natural alignment of a type when building a
+// coroutine frame; if the alignment specified on the Alloca instruction
+// differs from the natural alignment of the alloca type, we will need to
+// insert padding.
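+//
+// For example (illustrative numbers): with StructSize == 20, an i32 (ABI
+// alignment 4) would naturally land at offset 20, but a forced alignment of
+// 16 requires offset 32, so computePadding returns 32 - 20 == 12 bytes of
+// padding.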
+struct PaddingCalculator {
+ const DataLayout &DL;
+ LLVMContext &Context;
+ unsigned StructSize = 0;
+
+ PaddingCalculator(LLVMContext &Context, DataLayout const &DL)
+ : DL(DL), Context(Context) {}
+
+ // Replicate the logic from IR/DataLayout.cpp to match field offset
+ // computation for LLVM structs.
+ void addType(Type *Ty) {
+ unsigned TyAlign = DL.getABITypeAlignment(Ty);
+ if ((StructSize & (TyAlign - 1)) != 0)
+ StructSize = alignTo(StructSize, TyAlign);
+
+ StructSize += DL.getTypeAllocSize(Ty); // Consume space for this data item.
+ }
+
+ void addTypes(SmallVectorImpl<Type *> const &Types) {
+ for (auto *Ty : Types)
+ addType(Ty);
+ }
+
+ unsigned computePadding(Type *Ty, unsigned ForcedAlignment) {
+ unsigned TyAlign = DL.getABITypeAlignment(Ty);
+ auto Natural = alignTo(StructSize, TyAlign);
+ auto Forced = alignTo(StructSize, ForcedAlignment);
+
+ // Return how many bytes of padding we need to insert.
+ if (Natural != Forced)
+ return std::max(Natural, Forced) - StructSize;
+
+ // Rely on natural alignment.
+ return 0;
+ }
+
+  // If padding is required, return the padding field type to insert.
+ ArrayType *getPaddingType(Type *Ty, unsigned ForcedAlignment) {
+ if (auto Padding = computePadding(Ty, ForcedAlignment))
+ return ArrayType::get(Type::getInt8Ty(Context), Padding);
+
+ return nullptr;
+ }
+};
+} // namespace
+
+// Build a struct that will keep state for an active coroutine.
+// struct f.frame {
+//   ResumeFnTy ResumeFnAddr;
+//   ResumeFnTy DestroyFnAddr;
+//   ... promise (if present) ...
+//   int ResumeIndex;
+//   ... spills ...
+// };
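+//
+// For example (illustrative, for a switch-lowered coroutine with two suspend
+// points, an i32 promise, and a single i64 spill):
+//
+//   %f.Frame = type { void (%f.Frame*)*, void (%f.Frame*)*, i32, i1, i64 }
+//
+// where the i1 holds the suspend index (Log2_64_Ceil(2) == 1 bit).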
+static StructType *buildFrameType(Function &F, coro::Shape &Shape,
+ SpillInfo &Spills) {
+ LLVMContext &C = F.getContext();
+ const DataLayout &DL = F.getParent()->getDataLayout();
+ PaddingCalculator Padder(C, DL);
+ SmallString<32> Name(F.getName());
+ Name.append(".Frame");
+ StructType *FrameTy = StructType::create(C, Name);
+ SmallVector<Type *, 8> Types;
+
+ AllocaInst *PromiseAlloca = Shape.getPromiseAlloca();
+
+ if (Shape.ABI == coro::ABI::Switch) {
+ auto *FramePtrTy = FrameTy->getPointerTo();
+ auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy,
+ /*IsVarArg=*/false);
+ auto *FnPtrTy = FnTy->getPointerTo();
+
+    // Figure out how wide the integer type storing the suspend index should be.
+ unsigned IndexBits = std::max(1U, Log2_64_Ceil(Shape.CoroSuspends.size()));
+ Type *PromiseType = PromiseAlloca
+ ? PromiseAlloca->getType()->getElementType()
+ : Type::getInt1Ty(C);
+ Type *IndexType = Type::getIntNTy(C, IndexBits);
+ Types.push_back(FnPtrTy);
+ Types.push_back(FnPtrTy);
+ Types.push_back(PromiseType);
+ Types.push_back(IndexType);
+ } else {
+ assert(PromiseAlloca == nullptr && "lowering doesn't support promises");
+ }
+
+ Value *CurrentDef = nullptr;
+
+ Padder.addTypes(Types);
+
+ // Create an entry for every spilled value.
+ for (auto &S : Spills) {
+ if (CurrentDef == S.def())
+ continue;
+
+ CurrentDef = S.def();
+ // PromiseAlloca was already added to Types array earlier.
+ if (CurrentDef == PromiseAlloca)
+ continue;
+
+ uint64_t Count = 1;
+ Type *Ty = nullptr;
+ if (auto *AI = dyn_cast<AllocaInst>(CurrentDef)) {
+ Ty = AI->getAllocatedType();
+ if (unsigned AllocaAlignment = AI->getAlignment()) {
+ // If alignment is specified in alloca, see if we need to insert extra
+ // padding.
+ if (auto PaddingTy = Padder.getPaddingType(Ty, AllocaAlignment)) {
+ Types.push_back(PaddingTy);
+ Padder.addType(PaddingTy);
+ }
+ }
+ if (auto *CI = dyn_cast<ConstantInt>(AI->getArraySize()))
+ Count = CI->getValue().getZExtValue();
+ else
+ report_fatal_error("Coroutines cannot handle non static allocas yet");
+ } else {
+ Ty = CurrentDef->getType();
+ }
+ S.setFieldIndex(Types.size());
+ if (Count == 1)
+ Types.push_back(Ty);
+ else
+ Types.push_back(ArrayType::get(Ty, Count));
+ Padder.addType(Ty);
+ }
+ FrameTy->setBody(Types);
+
+ switch (Shape.ABI) {
+ case coro::ABI::Switch:
+ break;
+
+ // Remember whether the frame is inline in the storage.
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce: {
+ auto &Layout = F.getParent()->getDataLayout();
+ auto Id = Shape.getRetconCoroId();
+ Shape.RetconLowering.IsFrameInlineInStorage
+ = (Layout.getTypeAllocSize(FrameTy) <= Id->getStorageSize() &&
+ Layout.getABITypeAlignment(FrameTy) <= Id->getStorageAlignment());
+ break;
+ }
+ }
+
+ return FrameTy;
+}
+
+// We use a pointer use visitor to discover if there are any writes into an
+// alloca that dominate CoroBegin. If that is the case, insertSpills will copy
+// the value from the alloca into the coroutine frame spill slot corresponding
+// to that alloca.
+namespace {
+struct AllocaUseVisitor : PtrUseVisitor<AllocaUseVisitor> {
+ using Base = PtrUseVisitor<AllocaUseVisitor>;
+ AllocaUseVisitor(const DataLayout &DL, const DominatorTree &DT,
+ const CoroBeginInst &CB)
+ : PtrUseVisitor(DL), DT(DT), CoroBegin(CB) {}
+
+ // We are only interested in uses that dominate coro.begin.
+ void visit(Instruction &I) {
+ if (DT.dominates(&I, &CoroBegin))
+ Base::visit(I);
+ }
+ // We need to provide this overload as PtrUseVisitor uses a pointer based
+ // visiting function.
+ void visit(Instruction *I) { return visit(*I); }
+
+ void visitLoadInst(LoadInst &) {} // Good. Nothing to do.
+
+ // If the use is an operand, the pointer escaped and anything can write into
+ // that memory. If the use is the pointer, we are definitely writing into the
+ // alloca and therefore we need to copy.
+ void visitStoreInst(StoreInst &SI) { PI.setAborted(&SI); }
+
+  // Any other instruction that is not filtered out by PtrUseVisitor will
+ // result in the copy.
+ void visitInstruction(Instruction &I) { PI.setAborted(&I); }
+
+private:
+ const DominatorTree &DT;
+ const CoroBeginInst &CoroBegin;
+};
+} // namespace
+static bool mightWriteIntoAllocaPtr(AllocaInst &A, const DominatorTree &DT,
+ const CoroBeginInst &CB) {
+ const DataLayout &DL = A.getModule()->getDataLayout();
+ AllocaUseVisitor Visitor(DL, DT, CB);
+ auto PtrI = Visitor.visitPtr(A);
+ if (PtrI.isEscaped() || PtrI.isAborted()) {
+ auto *PointerEscapingInstr = PtrI.getEscapingInst()
+ ? PtrI.getEscapingInst()
+ : PtrI.getAbortingInst();
+ if (PointerEscapingInstr) {
+ LLVM_DEBUG(
+ dbgs() << "AllocaInst copy was triggered by instruction: "
+ << *PointerEscapingInstr << "\n");
+ }
+ return true;
+ }
+ return false;
+}
+
+// We need to make room to insert a spill after the initial PHIs, but before
+// the catchswitch instruction. Placing the spill before the catchswitch would
+// violate the requirement that a catchswitch, like all other EH pads, must be
+// the first non-PHI instruction in a block.
+//
+// Split the catchswitch away into a separate block and insert in its place:
+//
+//   cleanuppad <InsertPt> cleanupret
+//
+// The cleanupret instruction will act as an insert point for the spill.
+static Instruction *splitBeforeCatchSwitch(CatchSwitchInst *CatchSwitch) {
+ BasicBlock *CurrentBlock = CatchSwitch->getParent();
+ BasicBlock *NewBlock = CurrentBlock->splitBasicBlock(CatchSwitch);
+ CurrentBlock->getTerminator()->eraseFromParent();
+
+ auto *CleanupPad =
+ CleanupPadInst::Create(CatchSwitch->getParentPad(), {}, "", CurrentBlock);
+ auto *CleanupRet =
+ CleanupReturnInst::Create(CleanupPad, NewBlock, CurrentBlock);
+ return CleanupRet;
+}
+
+// Replace all alloca and SSA values that are accessed across suspend points
+// with GetElementPointer from coroutine frame + loads and stores. Create an
+// AllocaSpillBB that will become the new entry block for the resume parts of
+// the coroutine:
+//
+// %hdl = coro.begin(...)
+// whatever
+//
+// becomes:
+//
+// %hdl = coro.begin(...)
+//    %FramePtr = bitcast i8* %hdl to %f.frame*
+// br label %AllocaSpillBB
+//
+// AllocaSpillBB:
+// ; geps corresponding to allocas that were moved to coroutine frame
+//    br label %PostSpill
+//
+// PostSpill:
+// whatever
+//
+//
+static Instruction *insertSpills(const SpillInfo &Spills, coro::Shape &Shape) {
+ auto *CB = Shape.CoroBegin;
+ LLVMContext &C = CB->getContext();
+ IRBuilder<> Builder(CB->getNextNode());
+ StructType *FrameTy = Shape.FrameTy;
+ PointerType *FramePtrTy = FrameTy->getPointerTo();
+ auto *FramePtr =
+ cast<Instruction>(Builder.CreateBitCast(CB, FramePtrTy, "FramePtr"));
+ DominatorTree DT(*CB->getFunction());
+
+ Value *CurrentValue = nullptr;
+ BasicBlock *CurrentBlock = nullptr;
+ Value *CurrentReload = nullptr;
+
+ // Proper field number will be read from field definition.
+ unsigned Index = InvalidFieldIndex;
+
+  // We need to keep track of any allocas that need "spilling"
+  // since they will live in the coroutine frame now; all accesses to them
+  // need to be changed, not just the accesses across suspend points.
+  // We remember the allocas and their indices to be handled once we have
+  // processed all the spills.
+ SmallVector<std::pair<AllocaInst *, unsigned>, 4> Allocas;
+ // Promise alloca (if present) has a fixed field number.
+ if (auto *PromiseAlloca = Shape.getPromiseAlloca()) {
+ assert(Shape.ABI == coro::ABI::Switch);
+ Allocas.emplace_back(PromiseAlloca, coro::Shape::SwitchFieldIndex::Promise);
+ }
+
+ // Create a GEP with the given index into the coroutine frame for the original
+ // value Orig. Appends an extra 0 index for array-allocas, preserving the
+ // original type.
+ auto GetFramePointer = [&](uint32_t Index, Value *Orig) -> Value * {
+ SmallVector<Value *, 3> Indices = {
+ ConstantInt::get(Type::getInt32Ty(C), 0),
+ ConstantInt::get(Type::getInt32Ty(C), Index),
+ };
+
+ if (auto *AI = dyn_cast<AllocaInst>(Orig)) {
+ if (auto *CI = dyn_cast<ConstantInt>(AI->getArraySize())) {
+ auto Count = CI->getValue().getZExtValue();
+ if (Count > 1) {
+ Indices.push_back(ConstantInt::get(Type::getInt32Ty(C), 0));
+ }
+ } else {
+ report_fatal_error("Coroutines cannot handle non static allocas yet");
+ }
+ }
+
+ return Builder.CreateInBoundsGEP(FrameTy, FramePtr, Indices);
+ };
+
+ // Create a load instruction to reload the spilled value from the coroutine
+ // frame.
+ auto CreateReload = [&](Instruction *InsertBefore) {
+ assert(Index != InvalidFieldIndex && "accessing unassigned field number");
+ Builder.SetInsertPoint(InsertBefore);
+
+ auto *G = GetFramePointer(Index, CurrentValue);
+ G->setName(CurrentValue->getName() + Twine(".reload.addr"));
+
+ return isa<AllocaInst>(CurrentValue)
+ ? G
+ : Builder.CreateLoad(FrameTy->getElementType(Index), G,
+ CurrentValue->getName() + Twine(".reload"));
+ };
+
+ for (auto const &E : Spills) {
+ // If we have not seen the value, generate a spill.
+ if (CurrentValue != E.def()) {
+ CurrentValue = E.def();
+ CurrentBlock = nullptr;
+ CurrentReload = nullptr;
+
+ Index = E.fieldIndex();
+
+ if (auto *AI = dyn_cast<AllocaInst>(CurrentValue)) {
+        // A spilled AllocaInst will be replaced with a GEP into the coroutine
+        // frame, so there is no spill required.
+ Allocas.emplace_back(AI, Index);
+ if (!AI->isStaticAlloca())
+ report_fatal_error("Coroutines cannot handle non static allocas yet");
+ } else {
+ // Otherwise, create a store instruction storing the value into the
+ // coroutine frame.
+
+ Instruction *InsertPt = nullptr;
+ if (auto Arg = dyn_cast<Argument>(CurrentValue)) {
+ // For arguments, we will place the store instruction right after
+ // the coroutine frame pointer instruction, i.e. bitcast of
+ // coro.begin from i8* to %f.frame*.
+ InsertPt = FramePtr->getNextNode();
+
+ // If we're spilling an Argument, make sure we clear 'nocapture'
+ // from the coroutine function.
+ Arg->getParent()->removeParamAttr(Arg->getArgNo(),
+ Attribute::NoCapture);
+
+ } else if (auto *II = dyn_cast<InvokeInst>(CurrentValue)) {
+ // If we are spilling the result of the invoke instruction, split the
+ // normal edge and insert the spill in the new block.
+ auto NewBB = SplitEdge(II->getParent(), II->getNormalDest());
+ InsertPt = NewBB->getTerminator();
+ } else if (isa<PHINode>(CurrentValue)) {
+        // Skip PHINode and EH pad instructions.
+ BasicBlock *DefBlock = cast<Instruction>(E.def())->getParent();
+ if (auto *CSI = dyn_cast<CatchSwitchInst>(DefBlock->getTerminator()))
+ InsertPt = splitBeforeCatchSwitch(CSI);
+ else
+ InsertPt = &*DefBlock->getFirstInsertionPt();
+ } else if (auto CSI = dyn_cast<AnyCoroSuspendInst>(CurrentValue)) {
+ // Don't spill immediately after a suspend; splitting assumes
+ // that the suspend will be followed by a branch.
+ InsertPt = CSI->getParent()->getSingleSuccessor()->getFirstNonPHI();
+ } else {
+ auto *I = cast<Instruction>(E.def());
+ assert(!I->isTerminator() && "unexpected terminator");
+ // For all other values, the spill is placed immediately after
+ // the definition.
+ if (DT.dominates(CB, I)) {
+ InsertPt = I->getNextNode();
+ } else {
+          // If it is not dominated by CoroBegin, the spill will be inserted
+          // immediately after the coroutine frame pointer is computed.
+ InsertPt = FramePtr->getNextNode();
+ }
+ }
+
+ Builder.SetInsertPoint(InsertPt);
+ auto *G = Builder.CreateConstInBoundsGEP2_32(
+ FrameTy, FramePtr, 0, Index,
+ CurrentValue->getName() + Twine(".spill.addr"));
+ Builder.CreateStore(CurrentValue, G);
+ }
+ }
+
+ // If we have not seen the use block, generate a reload in it.
+ if (CurrentBlock != E.userBlock()) {
+ CurrentBlock = E.userBlock();
+ CurrentReload = CreateReload(&*CurrentBlock->getFirstInsertionPt());
+ }
+
+ // If we have a single edge PHINode, remove it and replace it with a reload
+ // from the coroutine frame. (We already took care of multi edge PHINodes
+ // by rewriting them in the rewritePHIs function).
+ if (auto *PN = dyn_cast<PHINode>(E.user())) {
+ assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming "
+ "values in the PHINode");
+ PN->replaceAllUsesWith(CurrentReload);
+ PN->eraseFromParent();
+ continue;
+ }
+
+ // Replace all uses of CurrentValue in the current instruction with reload.
+ E.user()->replaceUsesOfWith(CurrentValue, CurrentReload);
+ }
+
+ BasicBlock *FramePtrBB = FramePtr->getParent();
+
+ auto SpillBlock =
+ FramePtrBB->splitBasicBlock(FramePtr->getNextNode(), "AllocaSpillBB");
+ SpillBlock->splitBasicBlock(&SpillBlock->front(), "PostSpill");
+ Shape.AllocaSpillBlock = SpillBlock;
+  // If we found any allocas, replace all of their remaining uses with GEPs.
+ // Note: we cannot do it indiscriminately as some of the uses may not be
+ // dominated by CoroBegin.
+ bool MightNeedToCopy = false;
+ Builder.SetInsertPoint(&Shape.AllocaSpillBlock->front());
+ SmallVector<Instruction *, 4> UsersToUpdate;
+ for (auto &P : Allocas) {
+ AllocaInst *const A = P.first;
+ UsersToUpdate.clear();
+ for (User *U : A->users()) {
+ auto *I = cast<Instruction>(U);
+ if (DT.dominates(CB, I))
+ UsersToUpdate.push_back(I);
+ else
+ MightNeedToCopy = true;
+ }
+ if (!UsersToUpdate.empty()) {
+ auto *G = GetFramePointer(P.second, A);
+ G->takeName(A);
+ for (Instruction *I : UsersToUpdate)
+ I->replaceUsesOfWith(A, G);
+ }
+ }
+ // If we discovered such uses not dominated by CoroBegin, see if any of them
+  // precede coro.begin and have instructions that can modify the
+  // value of the alloca and therefore would require copying the value into
+ // the spill slot in the coroutine frame.
+ if (MightNeedToCopy) {
+ Builder.SetInsertPoint(FramePtr->getNextNode());
+
+ for (auto &P : Allocas) {
+ AllocaInst *const A = P.first;
+ if (mightWriteIntoAllocaPtr(*A, DT, *CB)) {
+ if (A->isArrayAllocation())
+ report_fatal_error(
+ "Coroutines cannot handle copying of array allocas yet");
+
+ auto *G = GetFramePointer(P.second, A);
+ auto *Value = Builder.CreateLoad(A);
+ Builder.CreateStore(Value, G);
+ }
+ }
+ }
+ return FramePtr;
+}
+
+// Sets the unwind edge of an instruction to a particular successor.
+static void setUnwindEdgeTo(Instruction *TI, BasicBlock *Succ) {
+ if (auto *II = dyn_cast<InvokeInst>(TI))
+ II->setUnwindDest(Succ);
+ else if (auto *CS = dyn_cast<CatchSwitchInst>(TI))
+ CS->setUnwindDest(Succ);
+ else if (auto *CR = dyn_cast<CleanupReturnInst>(TI))
+ CR->setUnwindDest(Succ);
+ else
+ llvm_unreachable("unexpected terminator instruction");
+}
+
+// Replaces all uses of OldPred with the NewPred block in all PHINodes in a
+// block.
+static void updatePhiNodes(BasicBlock *DestBB, BasicBlock *OldPred,
+ BasicBlock *NewPred,
+ PHINode *LandingPadReplacement) {
+ unsigned BBIdx = 0;
+ for (BasicBlock::iterator I = DestBB->begin(); isa<PHINode>(I); ++I) {
+ PHINode *PN = cast<PHINode>(I);
+
+ // We manually update the LandingPadReplacement PHINode and it is the last
+ // PHI Node. So, if we find it, we are done.
+ if (LandingPadReplacement == PN)
+ break;
+
+ // Reuse the previous value of BBIdx if it lines up. In cases where we
+ // have multiple phi nodes with *lots* of predecessors, this is a speed
+ // win because we don't have to scan the PHI looking for TIBB. This
+ // happens because the BB list of PHI nodes are usually in the same
+ // order.
+ if (PN->getIncomingBlock(BBIdx) != OldPred)
+ BBIdx = PN->getBasicBlockIndex(OldPred);
+
+ assert(BBIdx != (unsigned)-1 && "Invalid PHI Index!");
+ PN->setIncomingBlock(BBIdx, NewPred);
+ }
+}
+
+// Uses SplitEdge unless the successor block is an EHPad, in which case we do
+// EH-specific handling.
+static BasicBlock *ehAwareSplitEdge(BasicBlock *BB, BasicBlock *Succ,
+ LandingPadInst *OriginalPad,
+ PHINode *LandingPadReplacement) {
+ auto *PadInst = Succ->getFirstNonPHI();
+ if (!LandingPadReplacement && !PadInst->isEHPad())
+ return SplitEdge(BB, Succ);
+
+ auto *NewBB = BasicBlock::Create(BB->getContext(), "", BB->getParent(), Succ);
+ setUnwindEdgeTo(BB->getTerminator(), NewBB);
+ updatePhiNodes(Succ, BB, NewBB, LandingPadReplacement);
+
+ if (LandingPadReplacement) {
+ auto *NewLP = OriginalPad->clone();
+ auto *Terminator = BranchInst::Create(Succ, NewBB);
+ NewLP->insertBefore(Terminator);
+ LandingPadReplacement->addIncoming(NewLP, NewBB);
+ return NewBB;
+ }
+ Value *ParentPad = nullptr;
+ if (auto *FuncletPad = dyn_cast<FuncletPadInst>(PadInst))
+ ParentPad = FuncletPad->getParentPad();
+ else if (auto *CatchSwitch = dyn_cast<CatchSwitchInst>(PadInst))
+ ParentPad = CatchSwitch->getParentPad();
+ else
+ llvm_unreachable("handling for other EHPads not implemented yet");
+
+ auto *NewCleanupPad = CleanupPadInst::Create(ParentPad, {}, "", NewBB);
+ CleanupReturnInst::Create(NewCleanupPad, Succ, NewBB);
+ return NewBB;
+}
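+
+// An illustrative sketch of ehAwareSplitEdge for a funclet-based EH edge
+// (hypothetical IR; block and pad names are assumed, and the parent pad is
+// taken to be "none" for simplicity). An unwind edge from %bb to the EH pad
+// in %succ is rerouted through a fresh forwarding block:
+//
+//   bb:
+//     invoke void @g() to label %cont unwind label %succ.from.bb
+//   succ.from.bb:
+//     %pad = cleanuppad within none []
+//     cleanupret from %pad unwind label %succ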
+
+static void rewritePHIs(BasicBlock &BB) {
+  // For every incoming edge we will create a block holding all
+  // incoming values in single-entry PHI nodes.
+  //
+  // loop:
+  //   %n.val = phi i32 [%n, %entry], [%inc, %loop]
+  //
+  // It will create:
+  //
+  // loop.from.entry:
+  //   %n.loop.pre = phi i32 [%n, %entry]
+  //   br label %loop
+  // loop.from.loop:
+  //   %inc.loop.pre = phi i32 [%inc, %loop]
+  //   br label %loop
+ //
+ // After this rewrite, further analysis will ignore any phi nodes with more
+ // than one incoming edge.
+
+ // TODO: Simplify PHINodes in the basic block to remove duplicate
+ // predecessors.
+
+ LandingPadInst *LandingPad = nullptr;
+ PHINode *ReplPHI = nullptr;
+ if ((LandingPad = dyn_cast_or_null<LandingPadInst>(BB.getFirstNonPHI()))) {
+ // ehAwareSplitEdge will clone the LandingPad in all the edge blocks.
+ // We replace the original landing pad with a PHINode that will collect the
+ // results from all of them.
+ ReplPHI = PHINode::Create(LandingPad->getType(), 1, "", LandingPad);
+ ReplPHI->takeName(LandingPad);
+ LandingPad->replaceAllUsesWith(ReplPHI);
+ // We will erase the original landing pad at the end of this function after
+ // ehAwareSplitEdge cloned it in the transition blocks.
+ }
+
+ SmallVector<BasicBlock *, 8> Preds(pred_begin(&BB), pred_end(&BB));
+ for (BasicBlock *Pred : Preds) {
+ auto *IncomingBB = ehAwareSplitEdge(Pred, &BB, LandingPad, ReplPHI);
+ IncomingBB->setName(BB.getName() + Twine(".from.") + Pred->getName());
+ auto *PN = cast<PHINode>(&BB.front());
+ do {
+ int Index = PN->getBasicBlockIndex(IncomingBB);
+ Value *V = PN->getIncomingValue(Index);
+ PHINode *InputV = PHINode::Create(
+ V->getType(), 1, V->getName() + Twine(".") + BB.getName(),
+ &IncomingBB->front());
+ InputV->addIncoming(V, Pred);
+ PN->setIncomingValue(Index, InputV);
+ PN = dyn_cast<PHINode>(PN->getNextNode());
+ } while (PN != ReplPHI); // ReplPHI is either null or the PHI that replaced
+ // the landing pad.
+ }
+
+ if (LandingPad) {
+    // The calls to ehAwareSplitEdge above cloned the original landing pad;
+    // it is no longer needed.
+ LandingPad->eraseFromParent();
+ }
+}
+
+static void rewritePHIs(Function &F) {
+ SmallVector<BasicBlock *, 8> WorkList;
+
+ for (BasicBlock &BB : F)
+ if (auto *PN = dyn_cast<PHINode>(&BB.front()))
+ if (PN->getNumIncomingValues() > 1)
+ WorkList.push_back(&BB);
+
+ for (BasicBlock *BB : WorkList)
+ rewritePHIs(*BB);
+}
+
+// Check for instructions that we can recreate on resume as opposed to spilling
+// the result into the coroutine frame.
+static bool materializable(Instruction &V) {
+ return isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
+ isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V);
+}
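+
+// For example (hypothetical IR), a GEP defined before a suspend point and used
+// after it can be re-created at the use instead of being spilled; only %base
+// then needs to live in the coroutine frame:
+//
+//   %p = getelementptr inbounds i64, i64* %base, i64 1   ; materializable
+//   ; ... suspend point ...
+//   ; uses of %p after the suspend get a re-emitted copy of the GEP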
+
+// Check for structural coroutine intrinsics that should not be spilled into
+// the coroutine frame.
+static bool isCoroutineStructureIntrinsic(Instruction &I) {
+ return isa<CoroIdInst>(&I) || isa<CoroSaveInst>(&I) ||
+ isa<CoroSuspendInst>(&I);
+}
+
+// For every use of a value that crosses a suspend point, recreate that value
+// after the suspend point.
+static void rewriteMaterializableInstructions(IRBuilder<> &IRB,
+ SpillInfo const &Spills) {
+ BasicBlock *CurrentBlock = nullptr;
+ Instruction *CurrentMaterialization = nullptr;
+ Instruction *CurrentDef = nullptr;
+
+ for (auto const &E : Spills) {
+ // If it is a new definition, update CurrentXXX variables.
+ if (CurrentDef != E.def()) {
+ CurrentDef = cast<Instruction>(E.def());
+ CurrentBlock = nullptr;
+ CurrentMaterialization = nullptr;
+ }
+
+ // If we have not seen this block, materialize the value.
+ if (CurrentBlock != E.userBlock()) {
+ CurrentBlock = E.userBlock();
+ CurrentMaterialization = cast<Instruction>(CurrentDef)->clone();
+ CurrentMaterialization->setName(CurrentDef->getName());
+ CurrentMaterialization->insertBefore(
+ &*CurrentBlock->getFirstInsertionPt());
+ }
+
+ if (auto *PN = dyn_cast<PHINode>(E.user())) {
+ assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming "
+ "values in the PHINode");
+ PN->replaceAllUsesWith(CurrentMaterialization);
+ PN->eraseFromParent();
+ continue;
+ }
+
+ // Replace all uses of CurrentDef in the current instruction with the
+ // CurrentMaterialization for the block.
+ E.user()->replaceUsesOfWith(CurrentDef, CurrentMaterialization);
+ }
+}
+
+// Splits the block at a particular instruction unless it is the first
+// instruction in the block with a single predecessor.
+static BasicBlock *splitBlockIfNotFirst(Instruction *I, const Twine &Name) {
+ auto *BB = I->getParent();
+ if (&BB->front() == I) {
+ if (BB->getSinglePredecessor()) {
+ BB->setName(Name);
+ return BB;
+ }
+ }
+ return BB->splitBasicBlock(I, Name);
+}
+
+// Split above and below a particular instruction so that it
+// ends up alone in its own block.
+static void splitAround(Instruction *I, const Twine &Name) {
+ splitBlockIfNotFirst(I, Name);
+ splitBlockIfNotFirst(I->getNextNode(), "After" + Name);
+}
+
+static bool isSuspendBlock(BasicBlock *BB) {
+ return isa<AnyCoroSuspendInst>(BB->front());
+}
+
+typedef SmallPtrSet<BasicBlock*, 8> VisitedBlocksSet;
+
+/// Does control flow starting at the given block ever reach a suspend
+/// instruction before reaching a block in VisitedOrFreeBBs?
+static bool isSuspendReachableFrom(BasicBlock *From,
+ VisitedBlocksSet &VisitedOrFreeBBs) {
+ // Eagerly try to add this block to the visited set. If it's already
+ // there, stop recursing; this path doesn't reach a suspend before
+ // either looping or reaching a freeing block.
+ if (!VisitedOrFreeBBs.insert(From).second)
+ return false;
+
+ // We assume that we'll already have split suspends into their own blocks.
+ if (isSuspendBlock(From))
+ return true;
+
+ // Recurse on the successors.
+ for (auto Succ : successors(From)) {
+ if (isSuspendReachableFrom(Succ, VisitedOrFreeBBs))
+ return true;
+ }
+
+ return false;
+}
+
+/// Is the given alloca "local", i.e. bounded in lifetime to not cross a
+/// suspend point?
+static bool isLocalAlloca(CoroAllocaAllocInst *AI) {
+ // Seed the visited set with all the basic blocks containing a free
+ // so that we won't pass them up.
+ VisitedBlocksSet VisitedOrFreeBBs;
+ for (auto User : AI->users()) {
+ if (auto FI = dyn_cast<CoroAllocaFreeInst>(User))
+ VisitedOrFreeBBs.insert(FI->getParent());
+ }
+
+ return !isSuspendReachableFrom(AI->getParent(), VisitedOrFreeBBs);
+}
+
+/// After we split the coroutine, will the given basic block be along
+/// an obvious exit path for the resumption function?
+static bool willLeaveFunctionImmediatelyAfter(BasicBlock *BB,
+ unsigned depth = 3) {
+ // If we've bottomed out our depth count, stop searching and assume
+ // that the path might loop back.
+ if (depth == 0) return false;
+
+ // If this is a suspend block, we're about to exit the resumption function.
+ if (isSuspendBlock(BB)) return true;
+
+ // Recurse into the successors.
+ for (auto Succ : successors(BB)) {
+ if (!willLeaveFunctionImmediatelyAfter(Succ, depth - 1))
+ return false;
+ }
+
+ // If none of the successors leads back in a loop, we're on an exit/abort.
+ return true;
+}
+
+static bool localAllocaNeedsStackSave(CoroAllocaAllocInst *AI) {
+ // Look for a free that isn't sufficiently obviously followed by
+ // either a suspend or a termination, i.e. something that will leave
+ // the coro resumption frame.
+ for (auto U : AI->users()) {
+ auto FI = dyn_cast<CoroAllocaFreeInst>(U);
+ if (!FI) continue;
+
+ if (!willLeaveFunctionImmediatelyAfter(FI->getParent()))
+ return true;
+ }
+
+ // If we never found one, we don't need a stack save.
+ return false;
+}
+
+/// Turn each of the given local allocas into a normal (dynamic) alloca
+/// instruction.
+static void lowerLocalAllocas(ArrayRef<CoroAllocaAllocInst*> LocalAllocas,
+ SmallVectorImpl<Instruction*> &DeadInsts) {
+ for (auto AI : LocalAllocas) {
+ auto M = AI->getModule();
+ IRBuilder<> Builder(AI);
+
+ // Save the stack depth. Try to avoid doing this if the stackrestore
+    // is going to immediately precede a return or another function exit.
+ Value *StackSave = nullptr;
+ if (localAllocaNeedsStackSave(AI))
+ StackSave = Builder.CreateCall(
+ Intrinsic::getDeclaration(M, Intrinsic::stacksave));
+
+ // Allocate memory.
+ auto Alloca = Builder.CreateAlloca(Builder.getInt8Ty(), AI->getSize());
+ Alloca->setAlignment(MaybeAlign(AI->getAlignment()));
+
+ for (auto U : AI->users()) {
+ // Replace gets with the allocation.
+ if (isa<CoroAllocaGetInst>(U)) {
+ U->replaceAllUsesWith(Alloca);
+
+ // Replace frees with stackrestores. This is safe because
+ // alloca.alloc is required to obey a stack discipline, although we
+ // don't enforce that structurally.
+ } else {
+ auto FI = cast<CoroAllocaFreeInst>(U);
+ if (StackSave) {
+ Builder.SetInsertPoint(FI);
+ Builder.CreateCall(
+ Intrinsic::getDeclaration(M, Intrinsic::stackrestore),
+ StackSave);
+ }
+ }
+ DeadInsts.push_back(cast<Instruction>(U));
+ }
+
+ DeadInsts.push_back(AI);
+ }
+}
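+
+// A minimal sketch of the lowering above (hypothetical IR; %size stands for
+// the coro.alloca.alloc size operand):
+//
+//   %save = call i8* @llvm.stacksave()        ; only if a stack save is needed
+//   %mem  = alloca i8, i64 %size              ; replaces coro.alloca.get uses
+//   ; ...
+//   call void @llvm.stackrestore(i8* %save)   ; replaces coro.alloca.free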
+
+/// Turn the given coro.alloca.alloc call into a dynamic allocation.
+/// This happens during the all-instructions iteration, so it must not
+/// delete the call.
+static Instruction *lowerNonLocalAlloca(CoroAllocaAllocInst *AI,
+ coro::Shape &Shape,
+ SmallVectorImpl<Instruction*> &DeadInsts) {
+ IRBuilder<> Builder(AI);
+ auto Alloc = Shape.emitAlloc(Builder, AI->getSize(), nullptr);
+
+ for (User *U : AI->users()) {
+ if (isa<CoroAllocaGetInst>(U)) {
+ U->replaceAllUsesWith(Alloc);
+ } else {
+ auto FI = cast<CoroAllocaFreeInst>(U);
+ Builder.SetInsertPoint(FI);
+ Shape.emitDealloc(Builder, Alloc, nullptr);
+ }
+ DeadInsts.push_back(cast<Instruction>(U));
+ }
+
+ // Push this on last so that it gets deleted after all the others.
+ DeadInsts.push_back(AI);
+
+ // Return the new allocation value so that we can check for needed spills.
+ return cast<Instruction>(Alloc);
+}
+
+/// Get the current swifterror value.
+static Value *emitGetSwiftErrorValue(IRBuilder<> &Builder, Type *ValueTy,
+ coro::Shape &Shape) {
+ // Make a fake function pointer as a sort of intrinsic.
+ auto FnTy = FunctionType::get(ValueTy, {}, false);
+ auto Fn = ConstantPointerNull::get(FnTy->getPointerTo());
+
+ auto Call = Builder.CreateCall(Fn, {});
+ Shape.SwiftErrorOps.push_back(Call);
+
+ return Call;
+}
+
+/// Set the given value as the current swifterror value.
+///
+/// Returns a slot that can be used as a swifterror slot.
+static Value *emitSetSwiftErrorValue(IRBuilder<> &Builder, Value *V,
+ coro::Shape &Shape) {
+ // Make a fake function pointer as a sort of intrinsic.
+ auto FnTy = FunctionType::get(V->getType()->getPointerTo(),
+ {V->getType()}, false);
+ auto Fn = ConstantPointerNull::get(FnTy->getPointerTo());
+
+ auto Call = Builder.CreateCall(Fn, { V });
+ Shape.SwiftErrorOps.push_back(Call);
+
+ return Call;
+}
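+
+// Sketched pseudo-IR for the two markers above (T is the swifterror value
+// type; the null callees are placeholders tracked through Shape.SwiftErrorOps
+// and rewritten later, not real functions):
+//
+//   %err  = call T null()          ; "get the current swifterror value"
+//   %slot = call T* null(T %v)     ; "set the value, yielding a slot address"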
+
+/// Set the swifterror value from the given alloca before a call,
+/// then put it back in the alloca afterwards.
+///
+/// Returns an address that will stand in for the swifterror slot
+/// until splitting.
+static Value *emitSetAndGetSwiftErrorValueAround(Instruction *Call,
+ AllocaInst *Alloca,
+ coro::Shape &Shape) {
+ auto ValueTy = Alloca->getAllocatedType();
+ IRBuilder<> Builder(Call);
+
+ // Load the current value from the alloca and set it as the
+ // swifterror value.
+ auto ValueBeforeCall = Builder.CreateLoad(ValueTy, Alloca);
+ auto Addr = emitSetSwiftErrorValue(Builder, ValueBeforeCall, Shape);
+
+ // Move to after the call. Since swifterror only has a guaranteed
+ // value on normal exits, we can ignore implicit and explicit unwind
+ // edges.
+ if (isa<CallInst>(Call)) {
+ Builder.SetInsertPoint(Call->getNextNode());
+ } else {
+ auto Invoke = cast<InvokeInst>(Call);
+ Builder.SetInsertPoint(Invoke->getNormalDest()->getFirstNonPHIOrDbg());
+ }
+
+ // Get the current swifterror value and store it to the alloca.
+ auto ValueAfterCall = emitGetSwiftErrorValue(Builder, ValueTy, Shape);
+ Builder.CreateStore(ValueAfterCall, Alloca);
+
+ return Addr;
+}
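+
+// The net effect around a call, as a sketch (hypothetical IR; T is the
+// swifterror value type and @fn an arbitrary callee):
+//
+//   %v0   = load T, T* %alloca
+//   %slot = <set swifterror value to %v0>    ; marker call, returns a slot
+//   call void @fn(T* swifterror %slot)
+//   %v1   = <get current swifterror value>   ; marker call
+//   store T %v1, T* %alloca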
+
+/// Eliminate a formerly-swifterror alloca by inserting the get/set
+/// intrinsics and attempting to MemToReg the alloca away.
+static void eliminateSwiftErrorAlloca(Function &F, AllocaInst *Alloca,
+ coro::Shape &Shape) {
+ for (auto UI = Alloca->use_begin(), UE = Alloca->use_end(); UI != UE; ) {
+ // We're likely changing the use list, so use a mutation-safe
+ // iteration pattern.
+ auto &Use = *UI;
+ ++UI;
+
+ // swifterror values can only be used in very specific ways.
+ // We take advantage of that here.
+ auto User = Use.getUser();
+ if (isa<LoadInst>(User) || isa<StoreInst>(User))
+ continue;
+
+ assert(isa<CallInst>(User) || isa<InvokeInst>(User));
+ auto Call = cast<Instruction>(User);
+
+ auto Addr = emitSetAndGetSwiftErrorValueAround(Call, Alloca, Shape);
+
+ // Use the returned slot address as the call argument.
+ Use.set(Addr);
+ }
+
+ // All the uses should be loads and stores now.
+ assert(isAllocaPromotable(Alloca));
+}
+
+/// "Eliminate" a swifterror argument by reducing it to the alloca case
+/// and then loading and storing in the prologue and epilogue.
+///
+/// The argument keeps the swifterror flag.
+static void eliminateSwiftErrorArgument(Function &F, Argument &Arg,
+ coro::Shape &Shape,
+ SmallVectorImpl<AllocaInst*> &AllocasToPromote) {
+ IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHIOrDbg());
+
+ auto ArgTy = cast<PointerType>(Arg.getType());
+ auto ValueTy = ArgTy->getElementType();
+
+ // Reduce to the alloca case:
+
+ // Create an alloca and replace all uses of the arg with it.
+ auto Alloca = Builder.CreateAlloca(ValueTy, ArgTy->getAddressSpace());
+ Arg.replaceAllUsesWith(Alloca);
+
+ // Set an initial value in the alloca. swifterror is always null on entry.
+ auto InitialValue = Constant::getNullValue(ValueTy);
+ Builder.CreateStore(InitialValue, Alloca);
+
+ // Find all the suspends in the function and save and restore around them.
+ for (auto Suspend : Shape.CoroSuspends) {
+ (void) emitSetAndGetSwiftErrorValueAround(Suspend, Alloca, Shape);
+ }
+
+ // Find all the coro.ends in the function and restore the error value.
+ for (auto End : Shape.CoroEnds) {
+ Builder.SetInsertPoint(End);
+ auto FinalValue = Builder.CreateLoad(ValueTy, Alloca);
+ (void) emitSetSwiftErrorValue(Builder, FinalValue, Shape);
+ }
+
+ // Now we can use the alloca logic.
+ AllocasToPromote.push_back(Alloca);
+ eliminateSwiftErrorAlloca(F, Alloca, Shape);
+}
+
+/// Eliminate all problematic uses of swifterror arguments and allocas
+/// from the function. We'll fix them up later when splitting the function.
+static void eliminateSwiftError(Function &F, coro::Shape &Shape) {
+ SmallVector<AllocaInst*, 4> AllocasToPromote;
+
+ // Look for a swifterror argument.
+ for (auto &Arg : F.args()) {
+ if (!Arg.hasSwiftErrorAttr()) continue;
+
+ eliminateSwiftErrorArgument(F, Arg, Shape, AllocasToPromote);
+ break;
+ }
+
+ // Look for swifterror allocas.
+ for (auto &Inst : F.getEntryBlock()) {
+ auto Alloca = dyn_cast<AllocaInst>(&Inst);
+ if (!Alloca || !Alloca->isSwiftError()) continue;
+
+ // Clear the swifterror flag.
+ Alloca->setSwiftError(false);
+
+ AllocasToPromote.push_back(Alloca);
+ eliminateSwiftErrorAlloca(F, Alloca, Shape);
+ }
+
+ // If we have any allocas to promote, compute a dominator tree and
+ // promote them en masse.
+ if (!AllocasToPromote.empty()) {
+ DominatorTree DT(F);
+ PromoteMemToReg(AllocasToPromote, DT);
+ }
+}
+
+void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
+  // Lower llvm.dbg.declare to llvm.dbg.value, since we are going to rewrite
+ // access to local variables.
+ LowerDbgDeclare(F);
+
+ eliminateSwiftError(F, Shape);
+
+ if (Shape.ABI == coro::ABI::Switch &&
+ Shape.SwitchLowering.PromiseAlloca) {
+ Shape.getSwitchCoroId()->clearPromise();
+ }
+
+ // Make sure that all coro.save, coro.suspend and the fallthrough coro.end
+ // intrinsics are in their own blocks to simplify the logic of building up
+ // SuspendCrossing data.
+ for (auto *CSI : Shape.CoroSuspends) {
+ if (auto *Save = CSI->getCoroSave())
+ splitAround(Save, "CoroSave");
+ splitAround(CSI, "CoroSuspend");
+ }
+
+ // Put CoroEnds into their own blocks.
+ for (CoroEndInst *CE : Shape.CoroEnds)
+ splitAround(CE, "CoroEnd");
+
+  // Transform multi-edge PHI nodes, so that any value feeding into a PHI will
+  // never have its definition separated from the PHI by a suspend point.
+ rewritePHIs(F);
+
+ // Build suspend crossing info.
+ SuspendCrossingInfo Checker(F, Shape);
+
+ IRBuilder<> Builder(F.getContext());
+ SpillInfo Spills;
+ SmallVector<CoroAllocaAllocInst*, 4> LocalAllocas;
+ SmallVector<Instruction*, 4> DeadInstructions;
+
+ for (int Repeat = 0; Repeat < 4; ++Repeat) {
+ // See if there are materializable instructions across suspend points.
+ for (Instruction &I : instructions(F))
+ if (materializable(I))
+ for (User *U : I.users())
+ if (Checker.isDefinitionAcrossSuspend(I, U))
+ Spills.emplace_back(&I, U);
+
+ if (Spills.empty())
+ break;
+
+ // Rewrite materializable instructions to be materialized at the use point.
+ LLVM_DEBUG(dump("Materializations", Spills));
+ rewriteMaterializableInstructions(Builder, Spills);
+ Spills.clear();
+ }
+
+  // Collect the spills for arguments and other non-materializable values.
+ for (Argument &A : F.args())
+ for (User *U : A.users())
+ if (Checker.isDefinitionAcrossSuspend(A, U))
+ Spills.emplace_back(&A, U);
+
+ for (Instruction &I : instructions(F)) {
+    // Values returned from coroutine structure intrinsics should not be part
+    // of the coroutine frame.
+ if (isCoroutineStructureIntrinsic(I) || &I == Shape.CoroBegin)
+ continue;
+
+    // The coroutine promise is always included in the coroutine frame; no
+    // need to check for suspend crossing.
+ if (Shape.ABI == coro::ABI::Switch &&
+ Shape.SwitchLowering.PromiseAlloca == &I)
+ continue;
+
+ // Handle alloca.alloc specially here.
+ if (auto AI = dyn_cast<CoroAllocaAllocInst>(&I)) {
+ // Check whether the alloca's lifetime is bounded by suspend points.
+ if (isLocalAlloca(AI)) {
+ LocalAllocas.push_back(AI);
+ continue;
+ }
+
+ // If not, do a quick rewrite of the alloca and then add spills of
+ // the rewritten value. The rewrite doesn't invalidate anything in
+ // Spills because the other alloca intrinsics have no other operands
+ // besides AI, and it doesn't invalidate the iteration because we delay
+ // erasing AI.
+ auto Alloc = lowerNonLocalAlloca(AI, Shape, DeadInstructions);
+
+ for (User *U : Alloc->users()) {
+ if (Checker.isDefinitionAcrossSuspend(*Alloc, U))
+ Spills.emplace_back(Alloc, U);
+ }
+ continue;
+ }
+
+ // Ignore alloca.get; we process this as part of coro.alloca.alloc.
+ if (isa<CoroAllocaGetInst>(I)) {
+ continue;
+ }
+
+ for (User *U : I.users())
+ if (Checker.isDefinitionAcrossSuspend(I, U)) {
+ // We cannot spill a token.
+ if (I.getType()->isTokenTy())
+ report_fatal_error(
+ "token definition is separated from the use by a suspend point");
+ Spills.emplace_back(&I, U);
+ }
+ }
+ LLVM_DEBUG(dump("Spills", Spills));
+ Shape.FrameTy = buildFrameType(F, Shape, Spills);
+ Shape.FramePtr = insertSpills(Spills, Shape);
+ lowerLocalAllocas(LocalAllocas, DeadInstructions);
+
+ for (auto I : DeadInstructions)
+ I->eraseFromParent();
+}
diff --git a/llvm/lib/Transforms/Coroutines/CoroInstr.h b/llvm/lib/Transforms/Coroutines/CoroInstr.h
new file mode 100644
index 000000000000..de2d2920cb15
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/CoroInstr.h
@@ -0,0 +1,515 @@
+//===-- CoroInstr.h - Coroutine Intrinsics Instruction Wrappers -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This file defines classes that make it really easy to deal with intrinsic
+// functions with the isa/dyn_cast family of functions. In particular, this
+// allows you to do things like:
+//
+// if (auto *SF = dyn_cast<CoroSubFnInst>(Inst))
+// ... SF->getFrame() ...
+//
+// All intrinsic function calls are instances of the call instruction, so these
+// are all subclasses of the CallInst class. Note that none of these classes
+// has state or virtual methods, which is an important part of this gross/neat
+// hack working.
+//
+// The helpful comment above is borrowed from llvm/IntrinsicInst.h; we keep
+// coroutine intrinsic wrappers here since they are only used by the passes in
+// the Coroutine library.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TRANSFORMS_COROUTINES_COROINSTR_H
+#define LLVM_LIB_TRANSFORMS_COROUTINES_COROINSTR_H
+
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace llvm {
+
+/// This class represents the llvm.coro.subfn.addr instruction.
+class LLVM_LIBRARY_VISIBILITY CoroSubFnInst : public IntrinsicInst {
+ enum { FrameArg, IndexArg };
+
+public:
+ enum ResumeKind {
+ RestartTrigger = -1,
+ ResumeIndex,
+ DestroyIndex,
+ CleanupIndex,
+ IndexLast,
+ IndexFirst = RestartTrigger
+ };
+
+ Value *getFrame() const { return getArgOperand(FrameArg); }
+ ResumeKind getIndex() const {
+ int64_t Index = getRawIndex()->getValue().getSExtValue();
+ assert(Index >= IndexFirst && Index < IndexLast &&
+ "unexpected CoroSubFnInst index argument");
+ return static_cast<ResumeKind>(Index);
+ }
+
+ ConstantInt *getRawIndex() const {
+ return cast<ConstantInt>(getArgOperand(IndexArg));
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_subfn_addr;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.alloc instruction.
+class LLVM_LIBRARY_VISIBILITY CoroAllocInst : public IntrinsicInst {
+public:
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_alloc;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents a common base class for llvm.coro.id instructions.
+class LLVM_LIBRARY_VISIBILITY AnyCoroIdInst : public IntrinsicInst {
+public:
+ CoroAllocInst *getCoroAlloc() {
+ for (User *U : users())
+ if (auto *CA = dyn_cast<CoroAllocInst>(U))
+ return CA;
+ return nullptr;
+ }
+
+ IntrinsicInst *getCoroBegin() {
+ for (User *U : users())
+ if (auto *II = dyn_cast<IntrinsicInst>(U))
+ if (II->getIntrinsicID() == Intrinsic::coro_begin)
+ return II;
+ llvm_unreachable("no coro.begin associated with coro.id");
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ auto ID = I->getIntrinsicID();
+ return ID == Intrinsic::coro_id ||
+ ID == Intrinsic::coro_id_retcon ||
+ ID == Intrinsic::coro_id_retcon_once;
+ }
+
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.id instruction.
+class LLVM_LIBRARY_VISIBILITY CoroIdInst : public AnyCoroIdInst {
+ enum { AlignArg, PromiseArg, CoroutineArg, InfoArg };
+
+public:
+ AllocaInst *getPromise() const {
+ Value *Arg = getArgOperand(PromiseArg);
+ return isa<ConstantPointerNull>(Arg)
+ ? nullptr
+ : cast<AllocaInst>(Arg->stripPointerCasts());
+ }
+
+ void clearPromise() {
+ Value *Arg = getArgOperand(PromiseArg);
+ setArgOperand(PromiseArg,
+ ConstantPointerNull::get(Type::getInt8PtrTy(getContext())));
+ if (isa<AllocaInst>(Arg))
+ return;
+ assert((isa<BitCastInst>(Arg) || isa<GetElementPtrInst>(Arg)) &&
+ "unexpected instruction designating the promise");
+ // TODO: Add a check that any remaining users of Inst are after coro.begin
+ // or add code to move the users after coro.begin.
+ auto *Inst = cast<Instruction>(Arg);
+ if (Inst->use_empty()) {
+ Inst->eraseFromParent();
+ return;
+ }
+ Inst->moveBefore(getCoroBegin()->getNextNode());
+ }
+
+ // Info argument of coro.id is
+ // fresh out of the frontend: null ;
+ // outlined : {Init, Return, Susp1, Susp2, ...} ;
+ // postsplit : [resume, destroy, cleanup] ;
+ //
+ // If parts of the coroutine were outlined to protect against undesirable
+ // code motion, these functions will be stored in a struct literal referred to
+  // by the Info parameter. Note: this is only needed before the coroutine is
+  // split.
+  //
+  // After the coroutine is split, resume functions are stored in an array
+  // referred to by this parameter.
+
+ struct Info {
+ ConstantStruct *OutlinedParts = nullptr;
+ ConstantArray *Resumers = nullptr;
+
+ bool hasOutlinedParts() const { return OutlinedParts != nullptr; }
+ bool isPostSplit() const { return Resumers != nullptr; }
+ bool isPreSplit() const { return !isPostSplit(); }
+ };
+ Info getInfo() const {
+ Info Result;
+ auto *GV = dyn_cast<GlobalVariable>(getRawInfo());
+ if (!GV)
+ return Result;
+
+ assert(GV->isConstant() && GV->hasDefinitiveInitializer());
+ Constant *Initializer = GV->getInitializer();
+ if ((Result.OutlinedParts = dyn_cast<ConstantStruct>(Initializer)))
+ return Result;
+
+ Result.Resumers = cast<ConstantArray>(Initializer);
+ return Result;
+ }
+ Constant *getRawInfo() const {
+ return cast<Constant>(getArgOperand(InfoArg)->stripPointerCasts());
+ }
+
+ void setInfo(Constant *C) { setArgOperand(InfoArg, C); }
+
+ Function *getCoroutine() const {
+ return cast<Function>(getArgOperand(CoroutineArg)->stripPointerCasts());
+ }
+ void setCoroutineSelf() {
+ assert(isa<ConstantPointerNull>(getArgOperand(CoroutineArg)) &&
+ "Coroutine argument is already assigned");
+ auto *const Int8PtrTy = Type::getInt8PtrTy(getContext());
+ setArgOperand(CoroutineArg,
+ ConstantExpr::getBitCast(getFunction(), Int8PtrTy));
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_id;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents either the llvm.coro.id.retcon or
+/// llvm.coro.id.retcon.once instruction.
+class LLVM_LIBRARY_VISIBILITY AnyCoroIdRetconInst : public AnyCoroIdInst {
+ enum { SizeArg, AlignArg, StorageArg, PrototypeArg, AllocArg, DeallocArg };
+
+public:
+ void checkWellFormed() const;
+
+ uint64_t getStorageSize() const {
+ return cast<ConstantInt>(getArgOperand(SizeArg))->getZExtValue();
+ }
+
+ uint64_t getStorageAlignment() const {
+ return cast<ConstantInt>(getArgOperand(AlignArg))->getZExtValue();
+ }
+
+ Value *getStorage() const {
+ return getArgOperand(StorageArg);
+ }
+
+ /// Return the prototype for the continuation function. The type,
+ /// attributes, and calling convention of the continuation function(s)
+ /// are taken from this declaration.
+ Function *getPrototype() const {
+ return cast<Function>(getArgOperand(PrototypeArg)->stripPointerCasts());
+ }
+
+ /// Return the function to use for allocating memory.
+ Function *getAllocFunction() const {
+ return cast<Function>(getArgOperand(AllocArg)->stripPointerCasts());
+ }
+
+ /// Return the function to use for deallocating memory.
+ Function *getDeallocFunction() const {
+ return cast<Function>(getArgOperand(DeallocArg)->stripPointerCasts());
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ auto ID = I->getIntrinsicID();
+ return ID == Intrinsic::coro_id_retcon
+ || ID == Intrinsic::coro_id_retcon_once;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.id.retcon instruction.
+class LLVM_LIBRARY_VISIBILITY CoroIdRetconInst
+ : public AnyCoroIdRetconInst {
+public:
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_id_retcon;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.id.retcon.once instruction.
+class LLVM_LIBRARY_VISIBILITY CoroIdRetconOnceInst
+ : public AnyCoroIdRetconInst {
+public:
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_id_retcon_once;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.frame instruction.
+class LLVM_LIBRARY_VISIBILITY CoroFrameInst : public IntrinsicInst {
+public:
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_frame;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.free instruction.
+class LLVM_LIBRARY_VISIBILITY CoroFreeInst : public IntrinsicInst {
+ enum { IdArg, FrameArg };
+
+public:
+ Value *getFrame() const { return getArgOperand(FrameArg); }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_free;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This class represents the llvm.coro.begin instruction.
+class LLVM_LIBRARY_VISIBILITY CoroBeginInst : public IntrinsicInst {
+ enum { IdArg, MemArg };
+
+public:
+ AnyCoroIdInst *getId() const {
+ return cast<AnyCoroIdInst>(getArgOperand(IdArg));
+ }
+
+ Value *getMem() const { return getArgOperand(MemArg); }
+
+  // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_begin;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.save instruction.
+class LLVM_LIBRARY_VISIBILITY CoroSaveInst : public IntrinsicInst {
+public:
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_save;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.promise instruction.
+class LLVM_LIBRARY_VISIBILITY CoroPromiseInst : public IntrinsicInst {
+ enum { FrameArg, AlignArg, FromArg };
+
+public:
+ bool isFromPromise() const {
+ return cast<Constant>(getArgOperand(FromArg))->isOneValue();
+ }
+ unsigned getAlignment() const {
+ return cast<ConstantInt>(getArgOperand(AlignArg))->getZExtValue();
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_promise;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
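+/// This represents either the llvm.coro.suspend or the
+/// llvm.coro.suspend.retcon instruction.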
+class LLVM_LIBRARY_VISIBILITY AnyCoroSuspendInst : public IntrinsicInst {
+public:
+ CoroSaveInst *getCoroSave() const;
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_suspend ||
+ I->getIntrinsicID() == Intrinsic::coro_suspend_retcon;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.suspend instruction.
+class LLVM_LIBRARY_VISIBILITY CoroSuspendInst : public AnyCoroSuspendInst {
+ enum { SaveArg, FinalArg };
+
+public:
+ CoroSaveInst *getCoroSave() const {
+ Value *Arg = getArgOperand(SaveArg);
+ if (auto *SI = dyn_cast<CoroSaveInst>(Arg))
+ return SI;
+ assert(isa<ConstantTokenNone>(Arg));
+ return nullptr;
+ }
+
+ bool isFinal() const {
+ return cast<Constant>(getArgOperand(FinalArg))->isOneValue();
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_suspend;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+inline CoroSaveInst *AnyCoroSuspendInst::getCoroSave() const {
+ if (auto Suspend = dyn_cast<CoroSuspendInst>(this))
+ return Suspend->getCoroSave();
+ return nullptr;
+}
+
+/// This represents the llvm.coro.suspend.retcon instruction.
+class LLVM_LIBRARY_VISIBILITY CoroSuspendRetconInst : public AnyCoroSuspendInst {
+public:
+ op_iterator value_begin() { return arg_begin(); }
+ const_op_iterator value_begin() const { return arg_begin(); }
+
+ op_iterator value_end() { return arg_end(); }
+ const_op_iterator value_end() const { return arg_end(); }
+
+ iterator_range<op_iterator> value_operands() {
+ return make_range(value_begin(), value_end());
+ }
+ iterator_range<const_op_iterator> value_operands() const {
+ return make_range(value_begin(), value_end());
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_suspend_retcon;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.size instruction.
+class LLVM_LIBRARY_VISIBILITY CoroSizeInst : public IntrinsicInst {
+public:
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_size;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.end instruction.
+class LLVM_LIBRARY_VISIBILITY CoroEndInst : public IntrinsicInst {
+ enum { FrameArg, UnwindArg };
+
+public:
+ bool isFallthrough() const { return !isUnwind(); }
+ bool isUnwind() const {
+ return cast<Constant>(getArgOperand(UnwindArg))->isOneValue();
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_end;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.alloca.alloc instruction.
+class LLVM_LIBRARY_VISIBILITY CoroAllocaAllocInst : public IntrinsicInst {
+ enum { SizeArg, AlignArg };
+public:
+ Value *getSize() const {
+ return getArgOperand(SizeArg);
+ }
+ unsigned getAlignment() const {
+ return cast<ConstantInt>(getArgOperand(AlignArg))->getZExtValue();
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_alloca_alloc;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.alloca.get instruction.
+class LLVM_LIBRARY_VISIBILITY CoroAllocaGetInst : public IntrinsicInst {
+ enum { AllocArg };
+public:
+ CoroAllocaAllocInst *getAlloc() const {
+ return cast<CoroAllocaAllocInst>(getArgOperand(AllocArg));
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_alloca_get;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+/// This represents the llvm.coro.alloca.free instruction.
+class LLVM_LIBRARY_VISIBILITY CoroAllocaFreeInst : public IntrinsicInst {
+ enum { AllocArg };
+public:
+ CoroAllocaAllocInst *getAlloc() const {
+ return cast<CoroAllocaAllocInst>(getArgOperand(AllocArg));
+ }
+
+ // Methods to support type inquiry through isa, cast, and dyn_cast:
+ static bool classof(const IntrinsicInst *I) {
+ return I->getIntrinsicID() == Intrinsic::coro_alloca_free;
+ }
+ static bool classof(const Value *V) {
+ return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+ }
+};
+
+} // End namespace llvm.
+
+#endif
diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
new file mode 100644
index 000000000000..c151474316f9
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -0,0 +1,243 @@
+//===- CoroInternal.h - Internal Coroutine interfaces ---------*- C++ -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Common definitions/declarations used internally by coroutine lowering passes.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H
+#define LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H
+
+#include "CoroInstr.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Transforms/Coroutines.h"
+
+namespace llvm {
+
+class CallGraph;
+class CallGraphSCC;
+class PassRegistry;
+
+void initializeCoroEarlyPass(PassRegistry &);
+void initializeCoroSplitPass(PassRegistry &);
+void initializeCoroElidePass(PassRegistry &);
+void initializeCoroCleanupPass(PassRegistry &);
+
+// The CoroEarly pass marks every function that has coro.begin with the string
+// attribute "coroutine.presplit"="0". The CoroSplit pass processes each
+// coroutine twice. First, it lets the coroutine go through the complete IPO
+// optimization pipeline as a single function, and it forces a restart of the
+// pipeline by inserting an indirect call to the empty function
+// "coro.devirt.trigger". The CoroElide pass devirtualizes that call, which
+// causes the CGPassManager to restart the pipeline. When the CoroSplit pass
+// sees the same coroutine the second time, it splits it up and adds the
+// coroutine subfunctions to the SCC to be processed by the IPO pipeline.
+
+#define CORO_PRESPLIT_ATTR "coroutine.presplit"
+#define UNPREPARED_FOR_SPLIT "0"
+#define PREPARED_FOR_SPLIT "1"
+
+#define CORO_DEVIRT_TRIGGER_FN "coro.devirt.trigger"
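+
+// For illustration (hypothetical IR), a coroutine fresh out of CoroEarly
+// carries:
+//
+//   define void @f() "coroutine.presplit"="0" { ... }
+//
+// and the attribute value is assumed to change to "1" once the coroutine has
+// been prepared for splitting.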
+
+namespace coro {
+
+bool declaresIntrinsics(Module &M, std::initializer_list<StringRef>);
+void replaceAllCoroAllocs(CoroBeginInst *CB, bool Replacement);
+void replaceAllCoroFrees(CoroBeginInst *CB, Value *Replacement);
+void replaceCoroFree(CoroIdInst *CoroId, bool Elide);
+void updateCallGraph(Function &Caller, ArrayRef<Function *> Funcs,
+ CallGraph &CG, CallGraphSCC &SCC);
+
+// Keeps data and helper functions for lowering coroutine intrinsics.
+struct LowererBase {
+ Module &TheModule;
+ LLVMContext &Context;
+ PointerType *const Int8Ptr;
+ FunctionType *const ResumeFnType;
+ ConstantPointerNull *const NullPtr;
+
+ LowererBase(Module &M);
+ Value *makeSubFnCall(Value *Arg, int Index, Instruction *InsertPt);
+};
+
+enum class ABI {
+ /// The "resume-switch" lowering, where there are separate resume and
+ /// destroy functions that are shared between all suspend points. The
+ /// coroutine frame implicitly stores the resume and destroy functions,
+ /// the current index, and any promise value.
+ Switch,
+
+ /// The "returned-continuation" lowering, where each suspend point creates a
+ /// single continuation function that is used for both resuming and
+ /// destroying. Does not support promises.
+ Retcon,
+
+ /// The "unique returned-continuation" lowering, where each suspend point
+ /// creates a single continuation function that is used for both resuming
+ /// and destroying. Does not support promises. The function is known to
+ /// suspend at most once during its execution, and the return value of
+ /// the continuation is void.
+ RetconOnce,
+};
+
+// Holds the structural coroutine intrinsics for a particular function and
+// other values used during the CoroSplit pass.
+struct LLVM_LIBRARY_VISIBILITY Shape {
+ CoroBeginInst *CoroBegin;
+ SmallVector<CoroEndInst *, 4> CoroEnds;
+ SmallVector<CoroSizeInst *, 2> CoroSizes;
+ SmallVector<AnyCoroSuspendInst *, 4> CoroSuspends;
+ SmallVector<CallInst*, 2> SwiftErrorOps;
+
+ // Field indexes for special fields in the switch lowering.
+ struct SwitchFieldIndex {
+ enum {
+ Resume,
+ Destroy,
+ Promise,
+ Index,
+ /// The index of the first spill field.
+ FirstSpill
+ };
+ };
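+
+  // As a sketch, the frame for the switch lowering therefore begins with
+  // (assuming a promise type %P and an i32 index; the index may be narrower):
+  //
+  //   %f.Frame = type { void (%f.Frame*)*,   ; Resume
+  //                     void (%f.Frame*)*,   ; Destroy
+  //                     %P,                  ; Promise
+  //                     i32,                 ; Index
+  //                     ... spilled values ... }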
+
+ coro::ABI ABI;
+
+ StructType *FrameTy;
+ Instruction *FramePtr;
+ BasicBlock *AllocaSpillBlock;
+
+ struct SwitchLoweringStorage {
+ SwitchInst *ResumeSwitch;
+ AllocaInst *PromiseAlloca;
+ BasicBlock *ResumeEntryBlock;
+ bool HasFinalSuspend;
+ };
+
+ struct RetconLoweringStorage {
+ Function *ResumePrototype;
+ Function *Alloc;
+ Function *Dealloc;
+ BasicBlock *ReturnBlock;
+ bool IsFrameInlineInStorage;
+ };
+
+ union {
+ SwitchLoweringStorage SwitchLowering;
+ RetconLoweringStorage RetconLowering;
+ };
+
+ CoroIdInst *getSwitchCoroId() const {
+ assert(ABI == coro::ABI::Switch);
+ return cast<CoroIdInst>(CoroBegin->getId());
+ }
+
+ AnyCoroIdRetconInst *getRetconCoroId() const {
+ assert(ABI == coro::ABI::Retcon ||
+ ABI == coro::ABI::RetconOnce);
+ return cast<AnyCoroIdRetconInst>(CoroBegin->getId());
+ }
+
+ IntegerType *getIndexType() const {
+ assert(ABI == coro::ABI::Switch);
+ assert(FrameTy && "frame type not assigned");
+ return cast<IntegerType>(FrameTy->getElementType(SwitchFieldIndex::Index));
+ }
+ ConstantInt *getIndex(uint64_t Value) const {
+ return ConstantInt::get(getIndexType(), Value);
+ }
+
+ PointerType *getSwitchResumePointerType() const {
+ assert(ABI == coro::ABI::Switch);
+ assert(FrameTy && "frame type not assigned");
+ return cast<PointerType>(FrameTy->getElementType(SwitchFieldIndex::Resume));
+ }
+
+ FunctionType *getResumeFunctionType() const {
+ switch (ABI) {
+ case coro::ABI::Switch: {
+ auto *FnPtrTy = getSwitchResumePointerType();
+ return cast<FunctionType>(FnPtrTy->getPointerElementType());
+ }
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ return RetconLowering.ResumePrototype->getFunctionType();
+ }
+ llvm_unreachable("Unknown coro::ABI enum");
+ }
+
+ ArrayRef<Type*> getRetconResultTypes() const {
+ assert(ABI == coro::ABI::Retcon ||
+ ABI == coro::ABI::RetconOnce);
+ auto FTy = CoroBegin->getFunction()->getFunctionType();
+
+ // The safety of all this is checked by checkWFRetconPrototype.
+ if (auto STy = dyn_cast<StructType>(FTy->getReturnType())) {
+ return STy->elements().slice(1);
+ } else {
+ return ArrayRef<Type*>();
+ }
+ }
+
+ ArrayRef<Type*> getRetconResumeTypes() const {
+ assert(ABI == coro::ABI::Retcon ||
+ ABI == coro::ABI::RetconOnce);
+
+ // The safety of all this is checked by checkWFRetconPrototype.
+ auto FTy = RetconLowering.ResumePrototype->getFunctionType();
+ return FTy->params().slice(1);
+ }
+
+ CallingConv::ID getResumeFunctionCC() const {
+ switch (ABI) {
+ case coro::ABI::Switch:
+ return CallingConv::Fast;
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ return RetconLowering.ResumePrototype->getCallingConv();
+ }
+ llvm_unreachable("Unknown coro::ABI enum");
+ }
+
+ unsigned getFirstSpillFieldIndex() const {
+ switch (ABI) {
+ case coro::ABI::Switch:
+ return SwitchFieldIndex::FirstSpill;
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ return 0;
+ }
+ llvm_unreachable("Unknown coro::ABI enum");
+ }
+
+ AllocaInst *getPromiseAlloca() const {
+ if (ABI == coro::ABI::Switch)
+ return SwitchLowering.PromiseAlloca;
+ return nullptr;
+ }
+
+ /// Allocate memory according to the rules of the active lowering.
+ ///
+ /// \param CG - if non-null, will be updated for the new call
+ Value *emitAlloc(IRBuilder<> &Builder, Value *Size, CallGraph *CG) const;
+
+ /// Deallocate memory according to the rules of the active lowering.
+ ///
+ /// \param CG - if non-null, will be updated for the new call
+ void emitDealloc(IRBuilder<> &Builder, Value *Ptr, CallGraph *CG) const;
+
+ Shape() = default;
+ explicit Shape(Function &F) { buildFrom(F); }
+ void buildFrom(Function &F);
+};
+
+void buildCoroutineFrame(Function &F, Shape &Shape);
+
+} // End namespace coro.
+} // End namespace llvm
+
+#endif
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
new file mode 100644
index 000000000000..04723cbde417
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -0,0 +1,1601 @@
+//===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This pass builds the coroutine frame and outlines resume and destroy parts
+// of the coroutine into separate functions.
+//
+// We present a coroutine to LLVM as an ordinary function with suspension
+// points marked up with intrinsics. We let the optimizer party on the coroutine
+// as a single function for as long as possible. Shortly before the coroutine is
+// eligible to be inlined into its callers, we split up the coroutine into parts
+// corresponding to the initial, resume and destroy invocations of the coroutine,
+// add them to the current SCC and restart the IPO pipeline to optimize the
+// coroutine subfunctions we extracted before proceeding to the caller of the
+// coroutine.
+//===----------------------------------------------------------------------===//
+
+#include "CoroInstr.h"
+#include "CoroInternal.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/Analysis/CallGraphSCCPass.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/CallingConv.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <initializer_list>
+#include <iterator>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "coro-split"
+
+namespace {
+
+/// A little helper class for building the clones of a coroutine: the
+/// resume/destroy/cleanup functions of the switch lowering and individual
+/// continuation functions.
+class CoroCloner {
+public:
+ enum class Kind {
+ /// The shared resume function for a switch lowering.
+ SwitchResume,
+
+ /// The shared unwind function for a switch lowering.
+ SwitchUnwind,
+
+ /// The shared cleanup function for a switch lowering.
+ SwitchCleanup,
+
+ /// An individual continuation function.
+ Continuation,
+ };
+private:
+ Function &OrigF;
+ Function *NewF;
+ const Twine &Suffix;
+ coro::Shape &Shape;
+ Kind FKind;
+ ValueToValueMapTy VMap;
+ IRBuilder<> Builder;
+ Value *NewFramePtr = nullptr;
+ Value *SwiftErrorSlot = nullptr;
+
+ /// The active suspend instruction; meaningful only for continuation ABIs.
+ AnyCoroSuspendInst *ActiveSuspend = nullptr;
+
+public:
+ /// Create a cloner for a switch lowering.
+ CoroCloner(Function &OrigF, const Twine &Suffix, coro::Shape &Shape,
+ Kind FKind)
+ : OrigF(OrigF), NewF(nullptr), Suffix(Suffix), Shape(Shape),
+ FKind(FKind), Builder(OrigF.getContext()) {
+ assert(Shape.ABI == coro::ABI::Switch);
+ }
+
+ /// Create a cloner for a continuation lowering.
+ CoroCloner(Function &OrigF, const Twine &Suffix, coro::Shape &Shape,
+ Function *NewF, AnyCoroSuspendInst *ActiveSuspend)
+ : OrigF(OrigF), NewF(NewF), Suffix(Suffix), Shape(Shape),
+ FKind(Kind::Continuation), Builder(OrigF.getContext()),
+ ActiveSuspend(ActiveSuspend) {
+ assert(Shape.ABI == coro::ABI::Retcon ||
+ Shape.ABI == coro::ABI::RetconOnce);
+ assert(NewF && "need existing function for continuation");
+ assert(ActiveSuspend && "need active suspend point for continuation");
+ }
+
+ Function *getFunction() const {
+ assert(NewF != nullptr && "declaration not yet set");
+ return NewF;
+ }
+
+ void create();
+
+private:
+ bool isSwitchDestroyFunction() {
+ switch (FKind) {
+ case Kind::Continuation:
+ case Kind::SwitchResume:
+ return false;
+ case Kind::SwitchUnwind:
+ case Kind::SwitchCleanup:
+ return true;
+ }
+ llvm_unreachable("Unknown CoroCloner::Kind enum");
+ }
+
+ void createDeclaration();
+ void replaceEntryBlock();
+ Value *deriveNewFramePointer();
+ void replaceRetconSuspendUses();
+ void replaceCoroSuspends();
+ void replaceCoroEnds();
+ void replaceSwiftErrorOps();
+ void handleFinalSuspend();
+ void maybeFreeContinuationStorage();
+};
+
+} // end anonymous namespace
+
+static void maybeFreeRetconStorage(IRBuilder<> &Builder, coro::Shape &Shape,
+ Value *FramePtr, CallGraph *CG) {
+ assert(Shape.ABI == coro::ABI::Retcon ||
+ Shape.ABI == coro::ABI::RetconOnce);
+ if (Shape.RetconLowering.IsFrameInlineInStorage)
+ return;
+
+ Shape.emitDealloc(Builder, FramePtr, CG);
+}
+
+/// Replace a non-unwind call to llvm.coro.end.
+static void replaceFallthroughCoroEnd(CoroEndInst *End, coro::Shape &Shape,
+ Value *FramePtr, bool InResume,
+ CallGraph *CG) {
+ // Start inserting right before the coro.end.
+ IRBuilder<> Builder(End);
+
+ // Create the return instruction.
+ switch (Shape.ABI) {
+ // The cloned functions in switch-lowering always return void.
+ case coro::ABI::Switch:
+ // coro.end doesn't immediately end the coroutine in the main function
+ // in this lowering, because we need to deallocate the coroutine.
+ if (!InResume)
+ return;
+ Builder.CreateRetVoid();
+ break;
+
+ // In unique continuation lowering, the continuations always return void.
+ // But we may have implicitly allocated storage.
+ case coro::ABI::RetconOnce:
+ maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
+ Builder.CreateRetVoid();
+ break;
+
+ // In non-unique continuation lowering, we signal completion by returning
+ // a null continuation.
+ case coro::ABI::Retcon: {
+ maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
+ auto RetTy = Shape.getResumeFunctionType()->getReturnType();
+ auto RetStructTy = dyn_cast<StructType>(RetTy);
+ PointerType *ContinuationTy =
+ cast<PointerType>(RetStructTy ? RetStructTy->getElementType(0) : RetTy);
+
+ Value *ReturnValue = ConstantPointerNull::get(ContinuationTy);
+ if (RetStructTy) {
+ ReturnValue = Builder.CreateInsertValue(UndefValue::get(RetStructTy),
+ ReturnValue, 0);
+ }
+ Builder.CreateRet(ReturnValue);
+ break;
+ }
+ }
+
+  // Remove the rest of the block by splitting it into an unreachable block.
+ auto *BB = End->getParent();
+ BB->splitBasicBlock(End);
+ BB->getTerminator()->eraseFromParent();
+}
+
+/// Replace an unwind call to llvm.coro.end.
+static void replaceUnwindCoroEnd(CoroEndInst *End, coro::Shape &Shape,
+ Value *FramePtr, bool InResume, CallGraph *CG){
+ IRBuilder<> Builder(End);
+
+ switch (Shape.ABI) {
+ // In switch-lowering, this does nothing in the main function.
+ case coro::ABI::Switch:
+ if (!InResume)
+ return;
+ break;
+
+ // In continuation-lowering, this frees the continuation storage.
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
+ break;
+ }
+
+  // If coro.end has an associated bundle, add a cleanupret instruction.
+ if (auto Bundle = End->getOperandBundle(LLVMContext::OB_funclet)) {
+ auto *FromPad = cast<CleanupPadInst>(Bundle->Inputs[0]);
+ auto *CleanupRet = Builder.CreateCleanupRet(FromPad, nullptr);
+ End->getParent()->splitBasicBlock(End);
+ CleanupRet->getParent()->getTerminator()->eraseFromParent();
+ }
+}
+
+static void replaceCoroEnd(CoroEndInst *End, coro::Shape &Shape,
+ Value *FramePtr, bool InResume, CallGraph *CG) {
+ if (End->isUnwind())
+ replaceUnwindCoroEnd(End, Shape, FramePtr, InResume, CG);
+ else
+ replaceFallthroughCoroEnd(End, Shape, FramePtr, InResume, CG);
+
+ auto &Context = End->getContext();
+ End->replaceAllUsesWith(InResume ? ConstantInt::getTrue(Context)
+ : ConstantInt::getFalse(Context));
+ End->eraseFromParent();
+}
+
+// Create an entry block for a resume function with a switch that will jump to
+// suspend points.
+static void createResumeEntryBlock(Function &F, coro::Shape &Shape) {
+ assert(Shape.ABI == coro::ABI::Switch);
+ LLVMContext &C = F.getContext();
+
+ // resume.entry:
+  //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr,
+  //                i32 0, i32 3
+  //  %index = load i32, i32* %index.addr
+  //  switch i32 %index, label %unreachable [
+ // i32 0, label %resume.0
+ // i32 1, label %resume.1
+ // ...
+ // ]
+
+ auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
+ auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
+
+ IRBuilder<> Builder(NewEntry);
+ auto *FramePtr = Shape.FramePtr;
+ auto *FrameTy = Shape.FrameTy;
+ auto *GepIndex = Builder.CreateStructGEP(
+ FrameTy, FramePtr, coro::Shape::SwitchFieldIndex::Index, "index.addr");
+ auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
+ auto *Switch =
+ Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
+ Shape.SwitchLowering.ResumeSwitch = Switch;
+
+ size_t SuspendIndex = 0;
+ for (auto *AnyS : Shape.CoroSuspends) {
+ auto *S = cast<CoroSuspendInst>(AnyS);
+ ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
+
+ // Replace CoroSave with a store to Index:
+ // %index.addr = getelementptr %f.frame... (index field number)
+ // store i32 0, i32* %index.addr1
+ auto *Save = S->getCoroSave();
+ Builder.SetInsertPoint(Save);
+ if (S->isFinal()) {
+ // Final suspend point is represented by storing zero in ResumeFnAddr.
+ auto *GepIndex = Builder.CreateStructGEP(FrameTy, FramePtr,
+ coro::Shape::SwitchFieldIndex::Resume,
+ "ResumeFn.addr");
+ auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
+ cast<PointerType>(GepIndex->getType())->getElementType()));
+ Builder.CreateStore(NullPtr, GepIndex);
+ } else {
+ auto *GepIndex = Builder.CreateStructGEP(
+ FrameTy, FramePtr, coro::Shape::SwitchFieldIndex::Index, "index.addr");
+ Builder.CreateStore(IndexVal, GepIndex);
+ }
+ Save->replaceAllUsesWith(ConstantTokenNone::get(C));
+ Save->eraseFromParent();
+
+ // Split block before and after coro.suspend and add a jump from an entry
+ // switch:
+ //
+ // whateverBB:
+ // whatever
+ // %0 = call i8 @llvm.coro.suspend(token none, i1 false)
+  //    switch i8 %0, label %suspend [i8 0, label %resume
+  //                                  i8 1, label %cleanup]
+ // becomes:
+ //
+ // whateverBB:
+ // whatever
+ // br label %resume.0.landing
+ //
+ // resume.0: ; <--- jump from the switch in the resume.entry
+ // %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
+ // br label %resume.0.landing
+ //
+ // resume.0.landing:
+  //    %1 = phi i8 [-1, %whateverBB], [%0, %resume.0]
+  //    switch i8 %1, label %suspend [i8 0, label %resume
+  //                                  i8 1, label %cleanup]
+
+ auto *SuspendBB = S->getParent();
+ auto *ResumeBB =
+ SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
+ auto *LandingBB = ResumeBB->splitBasicBlock(
+ S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
+ Switch->addCase(IndexVal, ResumeBB);
+
+ cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
+ auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
+ S->replaceAllUsesWith(PN);
+ PN->addIncoming(Builder.getInt8(-1), SuspendBB);
+ PN->addIncoming(S, ResumeBB);
+
+ ++SuspendIndex;
+ }
+
+ Builder.SetInsertPoint(UnreachBB);
+ Builder.CreateUnreachable();
+
+ Shape.SwitchLowering.ResumeEntryBlock = NewEntry;
+}
+
+
+// Rewrite final suspend point handling. We do not use the suspend index to
+// represent the final suspend point. Instead we zero out ResumeFnAddr in the
+// coroutine frame, since it is undefined behavior to resume a coroutine
+// suspended at the final suspend point. Thus, in the resume function, we can
+// simply remove the last case (when coro::Shape is built, the final suspend
+// point, if present, is always the last element of the CoroSuspends array).
+// In the destroy function, we add a code sequence to check if ResumeFnAddr
+// is null, and if so, jump to the appropriate label to handle cleanup from the
+// final suspend point.
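+//
+// For illustration, the check emitted in the destroy function looks roughly
+// like this (value and block names below are hypothetical):
+//
+//   %ResumeFn.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr,
+//                                           i32 0, i32 0
+//   %ResumeFn = load void (%f.Frame*)*, void (%f.Frame*)** %ResumeFn.addr
+//   %is.final = icmp eq void (%f.Frame*)* %ResumeFn, null
+//   br i1 %is.final, label %cleanup.final, label %Switch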
+void CoroCloner::handleFinalSuspend() {
+ assert(Shape.ABI == coro::ABI::Switch &&
+ Shape.SwitchLowering.HasFinalSuspend);
+ auto *Switch = cast<SwitchInst>(VMap[Shape.SwitchLowering.ResumeSwitch]);
+ auto FinalCaseIt = std::prev(Switch->case_end());
+ BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
+ Switch->removeCase(FinalCaseIt);
+ if (isSwitchDestroyFunction()) {
+ BasicBlock *OldSwitchBB = Switch->getParent();
+ auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
+ Builder.SetInsertPoint(OldSwitchBB->getTerminator());
+ auto *GepIndex = Builder.CreateStructGEP(Shape.FrameTy, NewFramePtr,
+ coro::Shape::SwitchFieldIndex::Resume,
+ "ResumeFn.addr");
+ auto *Load = Builder.CreateLoad(Shape.getSwitchResumePointerType(),
+ GepIndex);
+ auto *Cond = Builder.CreateIsNull(Load);
+ Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
+ OldSwitchBB->getTerminator()->eraseFromParent();
+ }
+}
+
+static Function *createCloneDeclaration(Function &OrigF, coro::Shape &Shape,
+ const Twine &Suffix,
+ Module::iterator InsertBefore) {
+ Module *M = OrigF.getParent();
+ auto *FnTy = Shape.getResumeFunctionType();
+
+ Function *NewF =
+ Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage,
+ OrigF.getName() + Suffix);
+ NewF->addParamAttr(0, Attribute::NonNull);
+ NewF->addParamAttr(0, Attribute::NoAlias);
+
+ M->getFunctionList().insert(InsertBefore, NewF);
+
+ return NewF;
+}
+
+/// Replace uses of the active llvm.coro.suspend.retcon call with the
+/// arguments to the continuation function.
+///
+/// This assumes that the builder has a meaningful insertion point.
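+///
+/// For illustration only (types and names below are hypothetical): if the
+/// suspend produces { i32, i64 } and the continuation was declared as
+///   void @f.resume.0(i8* %buf, i32 %a, i64 %b)
+/// then extractvalue uses of the old suspend result are rewritten to use %a
+/// and %b directly, and any remaining aggregate uses are fed a value rebuilt
+/// with insertvalue.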
+void CoroCloner::replaceRetconSuspendUses() {
+ assert(Shape.ABI == coro::ABI::Retcon ||
+ Shape.ABI == coro::ABI::RetconOnce);
+
+ auto NewS = VMap[ActiveSuspend];
+ if (NewS->use_empty()) return;
+
+ // Copy out all the continuation arguments after the buffer pointer into
+ // an easily-indexed data structure for convenience.
+ SmallVector<Value*, 8> Args;
+ for (auto I = std::next(NewF->arg_begin()), E = NewF->arg_end(); I != E; ++I)
+ Args.push_back(&*I);
+
+ // If the suspend returns a single scalar value, we can just do a simple
+ // replacement.
+ if (!isa<StructType>(NewS->getType())) {
+ assert(Args.size() == 1);
+ NewS->replaceAllUsesWith(Args.front());
+ return;
+ }
+
+ // Try to peephole extracts of an aggregate return.
+ for (auto UI = NewS->use_begin(), UE = NewS->use_end(); UI != UE; ) {
+ auto EVI = dyn_cast<ExtractValueInst>((UI++)->getUser());
+ if (!EVI || EVI->getNumIndices() != 1)
+ continue;
+
+ EVI->replaceAllUsesWith(Args[EVI->getIndices().front()]);
+ EVI->eraseFromParent();
+ }
+
+ // If we have no remaining uses, we're done.
+ if (NewS->use_empty()) return;
+
+ // Otherwise, we need to create an aggregate.
+ Value *Agg = UndefValue::get(NewS->getType());
+ for (size_t I = 0, E = Args.size(); I != E; ++I)
+ Agg = Builder.CreateInsertValue(Agg, Args[I], I);
+
+ NewS->replaceAllUsesWith(Agg);
+}
+
+void CoroCloner::replaceCoroSuspends() {
+ Value *SuspendResult;
+
+ switch (Shape.ABI) {
+ // In switch lowering, replace coro.suspend with the appropriate value
+ // for the type of function we're extracting.
+ // Replacing coro.suspend with (0) will result in control flow proceeding to
+ // a resume label associated with a suspend point; replacing it with (1) will
+ // result in control flow proceeding to a cleanup label associated with this
+ // suspend point.
+ case coro::ABI::Switch:
+ SuspendResult = Builder.getInt8(isSwitchDestroyFunction() ? 1 : 0);
+ break;
+
+ // In returned-continuation lowering, the arguments from earlier
+ // continuations are theoretically arbitrary, and they should have been
+ // spilled.
+ case coro::ABI::RetconOnce:
+ case coro::ABI::Retcon:
+ return;
+ }
+
+ for (AnyCoroSuspendInst *CS : Shape.CoroSuspends) {
+ // The active suspend was handled earlier.
+ if (CS == ActiveSuspend) continue;
+
+ auto *MappedCS = cast<AnyCoroSuspendInst>(VMap[CS]);
+ MappedCS->replaceAllUsesWith(SuspendResult);
+ MappedCS->eraseFromParent();
+ }
+}
+
+void CoroCloner::replaceCoroEnds() {
+ for (CoroEndInst *CE : Shape.CoroEnds) {
+ // We use a null call graph because there's no call graph node for
+ // the cloned function yet. We'll just be rebuilding that later.
+ auto NewCE = cast<CoroEndInst>(VMap[CE]);
+ replaceCoroEnd(NewCE, Shape, NewFramePtr, /*in resume*/ true, nullptr);
+ }
+}
+
+static void replaceSwiftErrorOps(Function &F, coro::Shape &Shape,
+ ValueToValueMapTy *VMap) {
+ Value *CachedSlot = nullptr;
+ auto getSwiftErrorSlot = [&](Type *ValueTy) -> Value * {
+ if (CachedSlot) {
+ assert(CachedSlot->getType()->getPointerElementType() == ValueTy &&
+ "multiple swifterror slots in function with different types");
+ return CachedSlot;
+ }
+
+ // Check if the function has a swifterror argument.
+ for (auto &Arg : F.args()) {
+ if (Arg.isSwiftError()) {
+ CachedSlot = &Arg;
+ assert(Arg.getType()->getPointerElementType() == ValueTy &&
+ "swifterror argument does not have expected type");
+ return &Arg;
+ }
+ }
+
+ // Create a swifterror alloca.
+ IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHIOrDbg());
+ auto Alloca = Builder.CreateAlloca(ValueTy);
+ Alloca->setSwiftError(true);
+
+ CachedSlot = Alloca;
+ return Alloca;
+ };
+
+ for (CallInst *Op : Shape.SwiftErrorOps) {
+ auto MappedOp = VMap ? cast<CallInst>((*VMap)[Op]) : Op;
+ IRBuilder<> Builder(MappedOp);
+
+ // If there are no arguments, this is a 'get' operation.
+ Value *MappedResult;
+ if (Op->getNumArgOperands() == 0) {
+ auto ValueTy = Op->getType();
+ auto Slot = getSwiftErrorSlot(ValueTy);
+ MappedResult = Builder.CreateLoad(ValueTy, Slot);
+ } else {
+ assert(Op->getNumArgOperands() == 1);
+ auto Value = MappedOp->getArgOperand(0);
+ auto ValueTy = Value->getType();
+ auto Slot = getSwiftErrorSlot(ValueTy);
+ Builder.CreateStore(Value, Slot);
+ MappedResult = Slot;
+ }
+
+ MappedOp->replaceAllUsesWith(MappedResult);
+ MappedOp->eraseFromParent();
+ }
+
+ // If we're updating the original function, we've invalidated SwiftErrorOps.
+ if (VMap == nullptr) {
+ Shape.SwiftErrorOps.clear();
+ }
+}
+
+void CoroCloner::replaceSwiftErrorOps() {
+ ::replaceSwiftErrorOps(*NewF, Shape, &VMap);
+}
+
+void CoroCloner::replaceEntryBlock() {
+ // In the original function, the AllocaSpillBlock is a block immediately
+ // following the allocation of the frame object which defines GEPs for
+ // all the allocas that have been moved into the frame, and it ends by
+ // branching to the original beginning of the coroutine. Make this
+ // the entry block of the cloned function.
+ auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
+ Entry->setName("entry" + Suffix);
+ Entry->moveBefore(&NewF->getEntryBlock());
+ Entry->getTerminator()->eraseFromParent();
+
+ // Clear all predecessors of the new entry block. There should be
+ // exactly one predecessor, which we created when splitting out
+ // AllocaSpillBlock to begin with.
+ assert(Entry->hasOneUse());
+ auto BranchToEntry = cast<BranchInst>(Entry->user_back());
+ assert(BranchToEntry->isUnconditional());
+ Builder.SetInsertPoint(BranchToEntry);
+ Builder.CreateUnreachable();
+ BranchToEntry->eraseFromParent();
+
+ // TODO: move any allocas into Entry that weren't moved into the frame.
+ // (Currently we move all allocas into the frame.)
+
+ // Branch from the entry to the appropriate place.
+ Builder.SetInsertPoint(Entry);
+ switch (Shape.ABI) {
+ case coro::ABI::Switch: {
+ // In switch-lowering, we built a resume-entry block in the original
+ // function. Make the entry block branch to this.
+ auto *SwitchBB =
+ cast<BasicBlock>(VMap[Shape.SwitchLowering.ResumeEntryBlock]);
+ Builder.CreateBr(SwitchBB);
+ break;
+ }
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce: {
+ // In continuation ABIs, we want to branch to immediately after the
+ // active suspend point. Earlier phases will have put the suspend in its
+ // own basic block, so just thread our jump directly to its successor.
+ auto MappedCS = cast<CoroSuspendRetconInst>(VMap[ActiveSuspend]);
+ auto Branch = cast<BranchInst>(MappedCS->getNextNode());
+ assert(Branch->isUnconditional());
+ Builder.CreateBr(Branch->getSuccessor(0));
+ break;
+ }
+ }
+}
+
+/// Derive the value of the new frame pointer.
+Value *CoroCloner::deriveNewFramePointer() {
+ // Builder should be inserting to the front of the new entry block.
+
+ switch (Shape.ABI) {
+ // In switch-lowering, the argument is the frame pointer.
+ case coro::ABI::Switch:
+ return &*NewF->arg_begin();
+
+ // In continuation-lowering, the argument is the opaque storage.
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce: {
+ Argument *NewStorage = &*NewF->arg_begin();
+ auto FramePtrTy = Shape.FrameTy->getPointerTo();
+
+ // If the storage is inline, just bitcast the storage to the frame type.
+ if (Shape.RetconLowering.IsFrameInlineInStorage)
+ return Builder.CreateBitCast(NewStorage, FramePtrTy);
+
+ // Otherwise, load the real frame from the opaque storage.
+ auto FramePtrPtr =
+ Builder.CreateBitCast(NewStorage, FramePtrTy->getPointerTo());
+ return Builder.CreateLoad(FramePtrPtr);
+ }
+ }
+ llvm_unreachable("bad ABI");
+}
+
+/// Clone the body of the original function into a resume function of
+/// some sort.
+void CoroCloner::create() {
+ // Create the new function if we don't already have one.
+ if (!NewF) {
+ NewF = createCloneDeclaration(OrigF, Shape, Suffix,
+ OrigF.getParent()->end());
+ }
+
+ // Replace all args with undefs. The buildCoroutineFrame algorithm has
+ // already rewritten accesses to the args that occur after suspend points
+ // with loads and stores to/from the coroutine frame.
+ for (Argument &A : OrigF.args())
+ VMap[&A] = UndefValue::get(A.getType());
+
+ SmallVector<ReturnInst *, 4> Returns;
+
+ // Ignore attempts to change certain attributes of the function.
+ // TODO: maybe there should be a way to suppress this during cloning?
+ auto savedVisibility = NewF->getVisibility();
+ auto savedUnnamedAddr = NewF->getUnnamedAddr();
+ auto savedDLLStorageClass = NewF->getDLLStorageClass();
+
+ // NewF's linkage (which CloneFunctionInto does *not* change) might not
+ // be compatible with the visibility of OrigF (which it *does* change),
+ // so protect against that.
+ auto savedLinkage = NewF->getLinkage();
+ NewF->setLinkage(llvm::GlobalValue::ExternalLinkage);
+
+ CloneFunctionInto(NewF, &OrigF, VMap, /*ModuleLevelChanges=*/true, Returns);
+
+ NewF->setLinkage(savedLinkage);
+ NewF->setVisibility(savedVisibility);
+ NewF->setUnnamedAddr(savedUnnamedAddr);
+ NewF->setDLLStorageClass(savedDLLStorageClass);
+
+ auto &Context = NewF->getContext();
+
+ // Replace the attributes of the new function:
+ auto OrigAttrs = NewF->getAttributes();
+ auto NewAttrs = AttributeList();
+
+ switch (Shape.ABI) {
+ case coro::ABI::Switch:
+ // Bootstrap attributes by copying function attributes from the
+ // original function. This should include optimization settings and so on.
+ NewAttrs = NewAttrs.addAttributes(Context, AttributeList::FunctionIndex,
+ OrigAttrs.getFnAttributes());
+ break;
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ // If we have a continuation prototype, just use its attributes,
+ // full-stop.
+ NewAttrs = Shape.RetconLowering.ResumePrototype->getAttributes();
+ break;
+ }
+
+ // Make the frame parameter nonnull and noalias.
+ NewAttrs = NewAttrs.addParamAttribute(Context, 0, Attribute::NonNull);
+ NewAttrs = NewAttrs.addParamAttribute(Context, 0, Attribute::NoAlias);
+
+ switch (Shape.ABI) {
+ // In these ABIs, the cloned functions always return 'void', and the
+ // existing return sites are meaningless. Note that for unique
+ // continuations, this includes the returns associated with suspends;
+ // this is fine because we can't suspend twice.
+ case coro::ABI::Switch:
+ case coro::ABI::RetconOnce:
+ // Remove old returns.
+ for (ReturnInst *Return : Returns)
+ changeToUnreachable(Return, /*UseLLVMTrap=*/false);
+ break;
+
+ // With multi-suspend continuations, we'll already have eliminated the
+ // original returns and inserted returns before all the suspend points,
+ // so we want to leave any returns in place.
+ case coro::ABI::Retcon:
+ break;
+ }
+
+ NewF->setAttributes(NewAttrs);
+ NewF->setCallingConv(Shape.getResumeFunctionCC());
+
+ // Set up the new entry block.
+ replaceEntryBlock();
+
+ Builder.SetInsertPoint(&NewF->getEntryBlock().front());
+ NewFramePtr = deriveNewFramePointer();
+
+ // Remap frame pointer.
+ Value *OldFramePtr = VMap[Shape.FramePtr];
+ NewFramePtr->takeName(OldFramePtr);
+ OldFramePtr->replaceAllUsesWith(NewFramePtr);
+
+ // Remap vFrame pointer.
+ auto *NewVFrame = Builder.CreateBitCast(
+ NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
+ Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
+ OldVFrame->replaceAllUsesWith(NewVFrame);
+
+ switch (Shape.ABI) {
+ case coro::ABI::Switch:
+ // Rewrite final suspend handling, as it is not done via the switch (this
+ // allows removing the final case from the switch, since it is undefined
+ // behavior to resume a coroutine suspended at the final suspend point).
+ if (Shape.SwitchLowering.HasFinalSuspend)
+ handleFinalSuspend();
+ break;
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ // Replace uses of the active suspend with the corresponding
+ // continuation-function arguments.
+ assert(ActiveSuspend != nullptr &&
+ "no active suspend when lowering a continuation-style coroutine");
+ replaceRetconSuspendUses();
+ break;
+ }
+
+ // Handle suspends.
+ replaceCoroSuspends();
+
+ // Handle swifterror.
+ replaceSwiftErrorOps();
+
+ // Remove coro.end intrinsics.
+ replaceCoroEnds();
+
+ // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
+ // to suppress deallocation code.
+ if (Shape.ABI == coro::ABI::Switch)
+ coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
+ /*Elide=*/ FKind == CoroCloner::Kind::SwitchCleanup);
+}
+
+// Create a resume clone by cloning the body of the original function, setting
+// a new entry block, and replacing coro.suspend with an appropriate value to
+// force the resume or cleanup path at every suspend point.
+static Function *createClone(Function &F, const Twine &Suffix,
+ coro::Shape &Shape, CoroCloner::Kind FKind) {
+ CoroCloner Cloner(F, Suffix, Shape, FKind);
+ Cloner.create();
+ return Cloner.getFunction();
+}
+
+/// Remove calls to llvm.coro.end in the original function.
+static void removeCoroEnds(coro::Shape &Shape, CallGraph *CG) {
+ for (auto End : Shape.CoroEnds) {
+ replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, CG);
+ }
+}
+
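+// Replace each llvm.coro.size intrinsic with a constant computed from the
+// layout of the frame type. A sketch of the rewrite (the concrete size below
+// is hypothetical):
+//
+//   %size = call i32 @llvm.coro.size.i32()   ; becomes:  i32 120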
+static void replaceFrameSize(coro::Shape &Shape) {
+ if (Shape.CoroSizes.empty())
+ return;
+
+ // In the same function all coro.sizes should have the same result type.
+ auto *SizeIntrin = Shape.CoroSizes.back();
+ Module *M = SizeIntrin->getModule();
+ const DataLayout &DL = M->getDataLayout();
+ auto Size = DL.getTypeAllocSize(Shape.FrameTy);
+ auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);
+
+ for (CoroSizeInst *CS : Shape.CoroSizes) {
+ CS->replaceAllUsesWith(SizeConstant);
+ CS->eraseFromParent();
+ }
+}
+
+// Create a global constant array containing pointers to the functions
+// provided, and set the Info parameter of CoroBegin to point at this
+// constant. Example:
+//
+// @f.resumers = internal constant [2 x void(%f.frame*)*]
+// [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
+// define void @f() {
+// ...
+// call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
+// i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
+//
+// Assumes that all the functions have the same signature.
+static void setCoroInfo(Function &F, coro::Shape &Shape,
+ ArrayRef<Function *> Fns) {
+ // This only works under the switch-lowering ABI because coro elision
+ // only works on the switch-lowering ABI.
+ assert(Shape.ABI == coro::ABI::Switch);
+
+ SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
+ assert(!Args.empty());
+ Function *Part = *Fns.begin();
+ Module *M = Part->getParent();
+ auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
+
+ auto *ConstVal = ConstantArray::get(ArrTy, Args);
+ auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
+ GlobalVariable::PrivateLinkage, ConstVal,
+ F.getName() + Twine(".resumers"));
+
+ // Update coro.begin instruction to refer to this constant.
+ LLVMContext &C = F.getContext();
+ auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
+ Shape.getSwitchCoroId()->setInfo(BC);
+}
+
+// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
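+//
+// A sketch of the emitted stores, with hypothetical names (the select appears
+// only when a coro.alloc is present):
+//
+//   store void (%f.Frame*)* @f.resume, void (%f.Frame*)** %resume.addr
+//   %sel = select i1 %alloc, void (%f.Frame*)* @f.destroy,
+//                            void (%f.Frame*)* @f.cleanup
+//   store void (%f.Frame*)* %sel, void (%f.Frame*)** %destroy.addr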
+static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
+ Function *DestroyFn, Function *CleanupFn) {
+ assert(Shape.ABI == coro::ABI::Switch);
+
+ IRBuilder<> Builder(Shape.FramePtr->getNextNode());
+ auto *ResumeAddr = Builder.CreateStructGEP(
+ Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume,
+ "resume.addr");
+ Builder.CreateStore(ResumeFn, ResumeAddr);
+
+ Value *DestroyOrCleanupFn = DestroyFn;
+
+ CoroIdInst *CoroId = Shape.getSwitchCoroId();
+ if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
+ // If there is a CoroAlloc and it returns false (meaning we elided the
+ // allocation), use CleanupFn instead of DestroyFn.
+ DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
+ }
+
+ auto *DestroyAddr = Builder.CreateStructGEP(
+ Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Destroy,
+ "destroy.addr");
+ Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
+}
+
+static void postSplitCleanup(Function &F) {
+ removeUnreachableBlocks(F);
+
+ // For now, we do a mandatory verification step because we don't
+ // entirely trust this pass. Note that we don't want to add a verifier
+ // pass to FPM below because it will also verify all the global data.
+ verifyFunction(F);
+
+ legacy::FunctionPassManager FPM(F.getParent());
+
+ FPM.add(createSCCPPass());
+ FPM.add(createCFGSimplificationPass());
+ FPM.add(createEarlyCSEPass());
+ FPM.add(createCFGSimplificationPass());
+
+ FPM.doInitialization();
+ FPM.run(F);
+ FPM.doFinalization();
+}
+
+// Assuming we arrived at the block NewBlock from the Prev instruction, store
+// the PHIs' incoming values in the ResolvedValues map.
+static void
+scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
+ DenseMap<Value *, Value *> &ResolvedValues) {
+ auto *PrevBB = Prev->getParent();
+ for (PHINode &PN : NewBlock->phis()) {
+ auto V = PN.getIncomingValueForBlock(PrevBB);
+ // See if we already resolved it.
+ auto VI = ResolvedValues.find(V);
+ if (VI != ResolvedValues.end())
+ V = VI->second;
+ // Remember the value.
+ ResolvedValues[&PN] = V;
+ }
+}
+
+// Replace a sequence of branches leading to a ret with a clone of that ret
+// instruction. A suspend point is represented by a switch; track the PHI
+// values and select the correct case successor when possible.
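+//
+// For example (hypothetical IR), starting from the first branch in:
+//
+//   br label %bb1
+// bb1:
+//   br label %bb2
+// bb2:
+//   ret void
+//
+// the initial instruction is replaced with a clone of 'ret void'.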
+static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
+ DenseMap<Value *, Value *> ResolvedValues;
+
+ Instruction *I = InitialInst;
+ while (I->isTerminator()) {
+ if (isa<ReturnInst>(I)) {
+ if (I != InitialInst)
+ ReplaceInstWithInst(InitialInst, I->clone());
+ return true;
+ }
+ if (auto *BR = dyn_cast<BranchInst>(I)) {
+ if (BR->isUnconditional()) {
+ BasicBlock *BB = BR->getSuccessor(0);
+ scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
+ I = BB->getFirstNonPHIOrDbgOrLifetime();
+ continue;
+ }
+ } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
+ Value *V = SI->getCondition();
+ auto it = ResolvedValues.find(V);
+ if (it != ResolvedValues.end())
+ V = it->second;
+ if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
+ BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
+ scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
+ I = BB->getFirstNonPHIOrDbgOrLifetime();
+ continue;
+ }
+ }
+ return false;
+ }
+ return false;
+}
+
+// Add musttail to any resume instruction that is immediately followed by a
+// suspend (i.e. ret). We do this even at -O0 to support guaranteed tail calls
+// for symmetric coroutine control transfer (C++ Coroutines TS extension).
+// This transformation is done only in the resume part of the coroutine, which
+// has a signature and calling convention identical to the coro.resume call.
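+//
+// A sketch of the rewritten resume call in the resume clone (all names below
+// are hypothetical):
+//
+//   %addr = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
+//   %fn = bitcast i8* %addr to void (i8*)*
+//   musttail call void %fn(i8* %hdl)
+//   ret void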
+static void addMustTailToCoroResumes(Function &F) {
+ bool changed = false;
+
+ // Collect potential resume instructions.
+ SmallVector<CallInst *, 4> Resumes;
+ for (auto &I : instructions(F))
+ if (auto *Call = dyn_cast<CallInst>(&I))
+ if (auto *CalledValue = Call->getCalledValue())
+ // The CoroEarly pass replaced coro.resume calls with indirect calls to an
+ // address returned by the CoroSubFnInst intrinsic. See if it is one of those.
+ if (isa<CoroSubFnInst>(CalledValue->stripPointerCasts()))
+ Resumes.push_back(Call);
+
+ // Set musttail on those that are followed by a ret instruction.
+ for (CallInst *Call : Resumes)
+ if (simplifyTerminatorLeadingToRet(Call->getNextNode())) {
+ Call->setTailCallKind(CallInst::TCK_MustTail);
+ changed = true;
+ }
+
+ if (changed)
+ removeUnreachableBlocks(F);
+}
+
+// The coroutine has no suspend points. Remove the heap allocation for the
+// coroutine frame if possible.
+static void handleNoSuspendCoroutine(coro::Shape &Shape) {
+ auto *CoroBegin = Shape.CoroBegin;
+ auto *CoroId = CoroBegin->getId();
+ auto *AllocInst = CoroId->getCoroAlloc();
+ switch (Shape.ABI) {
+ case coro::ABI::Switch: {
+ auto SwitchId = cast<CoroIdInst>(CoroId);
+ coro::replaceCoroFree(SwitchId, /*Elide=*/AllocInst != nullptr);
+ if (AllocInst) {
+ IRBuilder<> Builder(AllocInst);
+ // FIXME: Need to handle overaligned members.
+ auto *Frame = Builder.CreateAlloca(Shape.FrameTy);
+ auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
+ AllocInst->replaceAllUsesWith(Builder.getFalse());
+ AllocInst->eraseFromParent();
+ CoroBegin->replaceAllUsesWith(VFrame);
+ } else {
+ CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
+ }
+ break;
+ }
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ CoroBegin->replaceAllUsesWith(UndefValue::get(CoroBegin->getType()));
+ break;
+ }
+
+ CoroBegin->eraseFromParent();
+}
+
+// SimplifySuspendPoint needs to check that there are no calls between
+// coro_save and coro_suspend, since any of those calls may potentially resume
+// the coroutine, and if that is the case we cannot eliminate the suspend
+// point.
+static bool hasCallsInBlockBetween(Instruction *From, Instruction *To) {
+ for (Instruction *I = From; I != To; I = I->getNextNode()) {
+ // Assume that no intrinsic can resume the coroutine.
+ if (isa<IntrinsicInst>(I))
+ continue;
+
+ if (CallSite(I))
+ return true;
+ }
+ return false;
+}
+
+static bool hasCallsInBlocksBetween(BasicBlock *SaveBB, BasicBlock *ResDesBB) {
+ SmallPtrSet<BasicBlock *, 8> Set;
+ SmallVector<BasicBlock *, 8> Worklist;
+
+ Set.insert(SaveBB);
+ Worklist.push_back(ResDesBB);
+
+ // Accumulate all blocks between SaveBB and ResDesBB. Because CoroSaveIntr
+ // returns a token consumed by the suspend instruction, all blocks in between
+ // will eventually have to hit SaveBB when going backwards from ResDesBB.
+ while (!Worklist.empty()) {
+ auto *BB = Worklist.pop_back_val();
+ Set.insert(BB);
+ for (auto *Pred : predecessors(BB))
+ if (Set.count(Pred) == 0)
+ Worklist.push_back(Pred);
+ }
+
+ // SaveBB and ResDesBB are checked separately in hasCallsBetween.
+ Set.erase(SaveBB);
+ Set.erase(ResDesBB);
+
+ for (auto *BB : Set)
+ if (hasCallsInBlockBetween(BB->getFirstNonPHI(), nullptr))
+ return true;
+
+ return false;
+}
+
+static bool hasCallsBetween(Instruction *Save, Instruction *ResumeOrDestroy) {
+ auto *SaveBB = Save->getParent();
+ auto *ResumeOrDestroyBB = ResumeOrDestroy->getParent();
+
+ if (SaveBB == ResumeOrDestroyBB)
+ return hasCallsInBlockBetween(Save->getNextNode(), ResumeOrDestroy);
+
+ // Any calls from Save to the end of the block?
+ if (hasCallsInBlockBetween(Save->getNextNode(), nullptr))
+ return true;
+
+ // Any calls from the beginning of the block up to ResumeOrDestroy?
+ if (hasCallsInBlockBetween(ResumeOrDestroyBB->getFirstNonPHI(),
+ ResumeOrDestroy))
+ return true;
+
+ // Any calls in all of the blocks between SaveBB and ResumeOrDestroyBB?
+ if (hasCallsInBlocksBetween(SaveBB, ResumeOrDestroyBB))
+ return true;
+
+ return false;
+}
+
+// If a SuspendIntrin is preceded by Resume or Destroy, we can eliminate the
+// suspend point and replace it with normal control flow.
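+//
+// The pattern being matched looks roughly like this (names are hypothetical):
+//
+//   %save = call token @llvm.coro.save(i8* %hdl)
+//   %addr = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
+//   %fn = bitcast i8* %addr to void (i8*)*
+//   call void %fn(i8* %hdl)
+//   %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+//
+// If %hdl is the current coroutine and no other call can intervene, the
+// suspend is replaced with the constant index carried by the subfn.addr call
+// (0 for resume, 1 for destroy), and the indirect call is removed.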
+static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
+ CoroBeginInst *CoroBegin) {
+ Instruction *Prev = Suspend->getPrevNode();
+ if (!Prev) {
+ auto *Pred = Suspend->getParent()->getSinglePredecessor();
+ if (!Pred)
+ return false;
+ Prev = Pred->getTerminator();
+ }
+
+ CallSite CS{Prev};
+ if (!CS)
+ return false;
+
+ auto *CallInstr = CS.getInstruction();
+
+ auto *Callee = CS.getCalledValue()->stripPointerCasts();
+
+ // See if the callsite is for resumption or destruction of the coroutine.
+ auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
+ if (!SubFn)
+ return false;
+
+ // If it does not refer to the current coroutine, we cannot do anything with
+ // it.
+ if (SubFn->getFrame() != CoroBegin)
+ return false;
+
+ // See if the transformation is safe. Specifically, see if there are any
+ // calls in between Save and CallInstr. They can potentially resume the
+ // coroutine, rendering this optimization unsafe.
+ auto *Save = Suspend->getCoroSave();
+ if (hasCallsBetween(Save, CallInstr))
+ return false;
+
+ // Replace llvm.coro.suspend with the value that steers control flow down
+ // the resume or the cleanup path.
+ Suspend->replaceAllUsesWith(SubFn->getRawIndex());
+ Suspend->eraseFromParent();
+ Save->eraseFromParent();
+
+ // No longer need a call to coro.resume or coro.destroy.
+ if (auto *Invoke = dyn_cast<InvokeInst>(CallInstr)) {
+ BranchInst::Create(Invoke->getNormalDest(), Invoke);
+ }
+
+ // Grab the CalledValue from CS before erasing the CallInstr.
+ auto *CalledValue = CS.getCalledValue();
+ CallInstr->eraseFromParent();
+
+ // If it has no more users, remove it. Usually it is a bitcast of SubFn.
+ if (CalledValue != SubFn && CalledValue->user_empty())
+ if (auto *I = dyn_cast<Instruction>(CalledValue))
+ I->eraseFromParent();
+
+ // Now we are good to remove SubFn.
+ if (SubFn->user_empty())
+ SubFn->eraseFromParent();
+
+ return true;
+}
+
+// Remove suspend points that are simplified.
+static void simplifySuspendPoints(coro::Shape &Shape) {
+ // Currently, the only simplification we do is switch-lowering-specific.
+ if (Shape.ABI != coro::ABI::Switch)
+ return;
+
+ auto &S = Shape.CoroSuspends;
+ size_t I = 0, N = S.size();
+ if (N == 0)
+ return;
+ while (true) {
+ if (simplifySuspendPoint(cast<CoroSuspendInst>(S[I]), Shape.CoroBegin)) {
+ if (--N == I)
+ break;
+ std::swap(S[I], S[N]);
+ continue;
+ }
+ if (++I == N)
+ break;
+ }
+ S.resize(N);
+}
+
+static void splitSwitchCoroutine(Function &F, coro::Shape &Shape,
+ SmallVectorImpl<Function *> &Clones) {
+ assert(Shape.ABI == coro::ABI::Switch);
+
+ createResumeEntryBlock(F, Shape);
+ auto ResumeClone = createClone(F, ".resume", Shape,
+ CoroCloner::Kind::SwitchResume);
+ auto DestroyClone = createClone(F, ".destroy", Shape,
+ CoroCloner::Kind::SwitchUnwind);
+ auto CleanupClone = createClone(F, ".cleanup", Shape,
+ CoroCloner::Kind::SwitchCleanup);
+
+ postSplitCleanup(*ResumeClone);
+ postSplitCleanup(*DestroyClone);
+ postSplitCleanup(*CleanupClone);
+
+ addMustTailToCoroResumes(*ResumeClone);
+
+ // Store the addresses of the resume/destroy/cleanup functions in the
+ // coroutine frame.
+ updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
+
+ assert(Clones.empty());
+ Clones.push_back(ResumeClone);
+ Clones.push_back(DestroyClone);
+ Clones.push_back(CleanupClone);
+
+ // Create a constant array referring to the resume/destroy/cleanup functions,
+ // pointed to by the last argument of @llvm.coro.info, so that the CoroElide
+ // pass can determine the correct function to call.
+ setCoroInfo(F, Shape, Clones);
+}
+
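+// Split a returned-continuation coroutine: allocate the frame (unless it is
+// inline in the caller-provided storage), build one unified return block that
+// returns the next continuation function plus the yielded values, and clone
+// one continuation function per suspend point. A sketch of the unified return
+// block, assuming a single hypothetical i32 yield:
+//
+//   coro.return:
+//     %cont = phi i8* ...
+//     %val = phi i32 ...
+//     %ret0 = insertvalue { i8*, i32 } undef, i8* %cont, 0
+//     %ret = insertvalue { i8*, i32 } %ret0, i32 %val, 1
+//     ret { i8*, i32 } %ret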
+static void splitRetconCoroutine(Function &F, coro::Shape &Shape,
+ SmallVectorImpl<Function *> &Clones) {
+ assert(Shape.ABI == coro::ABI::Retcon ||
+ Shape.ABI == coro::ABI::RetconOnce);
+ assert(Clones.empty());
+
+ // Reset various things that the optimizer might have decided it
+ // "knows" about the coroutine function due to not seeing a return.
+ F.removeFnAttr(Attribute::NoReturn);
+ F.removeAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
+ F.removeAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
+
+ // Allocate the frame.
+ auto *Id = cast<AnyCoroIdRetconInst>(Shape.CoroBegin->getId());
+ Value *RawFramePtr;
+ if (Shape.RetconLowering.IsFrameInlineInStorage) {
+ RawFramePtr = Id->getStorage();
+ } else {
+ IRBuilder<> Builder(Id);
+
+ // Determine the size of the frame.
+ const DataLayout &DL = F.getParent()->getDataLayout();
+ auto Size = DL.getTypeAllocSize(Shape.FrameTy);
+
+ // Allocate. We don't need to update the call graph node because we're
+ // going to recompute it from scratch after splitting.
+ RawFramePtr = Shape.emitAlloc(Builder, Builder.getInt64(Size), nullptr);
+ RawFramePtr =
+ Builder.CreateBitCast(RawFramePtr, Shape.CoroBegin->getType());
+
+ // Stash the allocated frame pointer in the continuation storage.
+ auto Dest = Builder.CreateBitCast(Id->getStorage(),
+ RawFramePtr->getType()->getPointerTo());
+ Builder.CreateStore(RawFramePtr, Dest);
+ }
+
+ // Map all uses of llvm.coro.begin to the allocated frame pointer.
+ {
+ // Make sure we don't invalidate Shape.FramePtr.
+ TrackingVH<Instruction> Handle(Shape.FramePtr);
+ Shape.CoroBegin->replaceAllUsesWith(RawFramePtr);
+ Shape.FramePtr = Handle.getValPtr();
+ }
+
+ // Create a unique return block.
+ BasicBlock *ReturnBB = nullptr;
+ SmallVector<PHINode *, 4> ReturnPHIs;
+
+ // Create all the functions in order after the main function.
+ auto NextF = std::next(F.getIterator());
+
+ // Create a continuation function for each of the suspend points.
+ Clones.reserve(Shape.CoroSuspends.size());
+ for (size_t i = 0, e = Shape.CoroSuspends.size(); i != e; ++i) {
+ auto Suspend = cast<CoroSuspendRetconInst>(Shape.CoroSuspends[i]);
+
+ // Create the clone declaration.
+ auto Continuation =
+ createCloneDeclaration(F, Shape, ".resume." + Twine(i), NextF);
+ Clones.push_back(Continuation);
+
+ // Insert a branch to the unified return block immediately before
+ // the suspend point.
+ auto SuspendBB = Suspend->getParent();
+ auto NewSuspendBB = SuspendBB->splitBasicBlock(Suspend);
+ auto Branch = cast<BranchInst>(SuspendBB->getTerminator());
+
+ // Create the unified return block.
+ if (!ReturnBB) {
+ // Place it before the first suspend.
+ ReturnBB = BasicBlock::Create(F.getContext(), "coro.return", &F,
+ NewSuspendBB);
+ Shape.RetconLowering.ReturnBlock = ReturnBB;
+
+ IRBuilder<> Builder(ReturnBB);
+
+ // Create PHIs for all the return values.
+ assert(ReturnPHIs.empty());
+
+ // First, the continuation.
+ ReturnPHIs.push_back(Builder.CreatePHI(Continuation->getType(),
+ Shape.CoroSuspends.size()));
+
+ // Next, all the directly-yielded values.
+ for (auto ResultTy : Shape.getRetconResultTypes())
+ ReturnPHIs.push_back(Builder.CreatePHI(ResultTy,
+ Shape.CoroSuspends.size()));
+
+ // Build the return value.
+ auto RetTy = F.getReturnType();
+
+ // Cast the continuation value if necessary.
+ // We can't rely on the types matching up because that type would
+ // have to be infinite.
+ auto CastedContinuationTy =
+ (ReturnPHIs.size() == 1 ? RetTy : RetTy->getStructElementType(0));
+ auto *CastedContinuation =
+ Builder.CreateBitCast(ReturnPHIs[0], CastedContinuationTy);
+
+ Value *RetV;
+ if (ReturnPHIs.size() == 1) {
+ RetV = CastedContinuation;
+ } else {
+ RetV = UndefValue::get(RetTy);
+ RetV = Builder.CreateInsertValue(RetV, CastedContinuation, 0);
+ for (size_t I = 1, E = ReturnPHIs.size(); I != E; ++I)
+ RetV = Builder.CreateInsertValue(RetV, ReturnPHIs[I], I);
+ }
+
+ Builder.CreateRet(RetV);
+ }
+
+ // Branch to the return block.
+ Branch->setSuccessor(0, ReturnBB);
+ ReturnPHIs[0]->addIncoming(Continuation, SuspendBB);
+ size_t NextPHIIndex = 1;
+ for (auto &VUse : Suspend->value_operands())
+ ReturnPHIs[NextPHIIndex++]->addIncoming(&*VUse, SuspendBB);
+ assert(NextPHIIndex == ReturnPHIs.size());
+ }
+
+ assert(Clones.size() == Shape.CoroSuspends.size());
+ for (size_t i = 0, e = Shape.CoroSuspends.size(); i != e; ++i) {
+ auto Suspend = Shape.CoroSuspends[i];
+ auto Clone = Clones[i];
+
+ CoroCloner(F, "resume." + Twine(i), Shape, Clone, Suspend).create();
+ }
+}
+
+namespace {
+ class PrettyStackTraceFunction : public PrettyStackTraceEntry {
+ Function &F;
+ public:
+ PrettyStackTraceFunction(Function &F) : F(F) {}
+ void print(raw_ostream &OS) const override {
+ OS << "While splitting coroutine ";
+ F.printAsOperand(OS, /*print type*/ false, F.getParent());
+ OS << "\n";
+ }
+ };
+}
+
+static void splitCoroutine(Function &F, coro::Shape &Shape,
+ SmallVectorImpl<Function *> &Clones) {
+ switch (Shape.ABI) {
+ case coro::ABI::Switch:
+ return splitSwitchCoroutine(F, Shape, Clones);
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce:
+ return splitRetconCoroutine(F, Shape, Clones);
+ }
+ llvm_unreachable("bad ABI kind");
+}
+
+static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
+ PrettyStackTraceFunction prettyStackTrace(F);
+
+ // The suspend-crossing algorithm in buildCoroutineFrame gets tripped
+ // up by uses in unreachable blocks, so remove them as a first pass.
+ removeUnreachableBlocks(F);
+
+ coro::Shape Shape(F);
+ if (!Shape.CoroBegin)
+ return;
+
+ simplifySuspendPoints(Shape);
+ buildCoroutineFrame(F, Shape);
+ replaceFrameSize(Shape);
+
+ SmallVector<Function*, 4> Clones;
+
+ // If there are no suspend points, no split is required; just remove
+ // the allocation and deallocation blocks, as they are not needed.
+ if (Shape.CoroSuspends.empty()) {
+ handleNoSuspendCoroutine(Shape);
+ } else {
+ splitCoroutine(F, Shape, Clones);
+ }
+
+ // Replace all the swifterror operations in the original function.
+ // This invalidates SwiftErrorOps in the Shape.
+ replaceSwiftErrorOps(F, Shape, nullptr);
+
+ removeCoroEnds(Shape, &CG);
+ postSplitCleanup(F);
+
+ // Update call graph and add the functions we created to the SCC.
+ coro::updateCallGraph(F, Clones, CG, SCC);
+}
+
+// When we see the coroutine for the first time, we insert an indirect call to
+// a devirt trigger function and mark the coroutine as being ready for split.
+static void prepareForSplit(Function &F, CallGraph &CG) {
+ Module &M = *F.getParent();
+ LLVMContext &Context = F.getContext();
+#ifndef NDEBUG
+ Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
+ assert(DevirtFn && "coro.devirt.trigger function not found");
+#endif
+
+ F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
+
+ // Insert an indirect call sequence that will be devirtualized by the
+ // CoroElide pass:
+ // %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
+ // %1 = bitcast i8* %0 to void(i8*)*
+ // call void %1(i8* null)
+ coro::LowererBase Lowerer(M);
+ Instruction *InsertPt = F.getEntryBlock().getTerminator();
+ auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(Context));
+ auto *DevirtFnAddr =
+ Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
+ FunctionType *FnTy = FunctionType::get(Type::getVoidTy(Context),
+ {Type::getInt8PtrTy(Context)}, false);
+ auto *IndirectCall = CallInst::Create(FnTy, DevirtFnAddr, Null, "", InsertPt);
+
+ // Update the call graph with the indirect call we just added.
+ CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode());
+}
+
+// Make sure that there is a devirtualization trigger function that the
+// CoroSplit pass uses to force a restart of the CGSCC pipeline. If the devirt
+// trigger function is not found, we will create one and add it to the current
+// SCC.
+static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
+ Module &M = CG.getModule();
+ if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
+ return;
+
+ LLVMContext &C = M.getContext();
+ auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
+ /*isVarArg=*/false);
+ Function *DevirtFn =
+ Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
+ CORO_DEVIRT_TRIGGER_FN, &M);
+ DevirtFn->addFnAttr(Attribute::AlwaysInline);
+ auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
+ ReturnInst::Create(C, Entry);
+
+ auto *Node = CG.getOrInsertFunction(DevirtFn);
+
+ SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
+ Nodes.push_back(Node);
+ SCC.initialize(Nodes);
+}
+
+/// Replace a call to llvm.coro.prepare.retcon.
+static void replacePrepare(CallInst *Prepare, CallGraph &CG) {
+ auto CastFn = Prepare->getArgOperand(0); // as an i8*
+ auto Fn = CastFn->stripPointerCasts(); // as its original type
+
+ // Find call graph nodes for the preparation.
+ CallGraphNode *PrepareUserNode = nullptr, *FnNode = nullptr;
+ if (auto ConcreteFn = dyn_cast<Function>(Fn)) {
+ PrepareUserNode = CG[Prepare->getFunction()];
+ FnNode = CG[ConcreteFn];
+ }
+
+ // Attempt to peephole this pattern:
+ // %0 = bitcast [[TYPE]] @some_function to i8*
+ // %1 = call @llvm.coro.prepare.retcon(i8* %0)
+ // %2 = bitcast %1 to [[TYPE]]
+ // ==>
+ // %2 = @some_function
+ for (auto UI = Prepare->use_begin(), UE = Prepare->use_end();
+ UI != UE; ) {
+ // Look for bitcasts back to the original function type.
+ auto *Cast = dyn_cast<BitCastInst>((UI++)->getUser());
+ if (!Cast || Cast->getType() != Fn->getType()) continue;
+
+ // Check whether the replacement will introduce new direct calls.
+ // If so, we'll need to update the call graph.
+ if (PrepareUserNode) {
+ for (auto &Use : Cast->uses()) {
+ if (auto *CB = dyn_cast<CallBase>(Use.getUser())) {
+ if (!CB->isCallee(&Use))
+ continue;
+ PrepareUserNode->removeCallEdgeFor(*CB);
+ PrepareUserNode->addCalledFunction(CB, FnNode);
+ }
+ }
+ }
+
+ // Replace and remove the cast.
+ Cast->replaceAllUsesWith(Fn);
+ Cast->eraseFromParent();
+ }
+
+ // Replace any remaining uses with the function as an i8*.
+ // This can never directly be a callee, so we don't need to update CG.
+ Prepare->replaceAllUsesWith(CastFn);
+ Prepare->eraseFromParent();
+
+ // Kill dead bitcasts.
+ while (auto *Cast = dyn_cast<BitCastInst>(CastFn)) {
+ if (!Cast->use_empty()) break;
+ CastFn = Cast->getOperand(0);
+ Cast->eraseFromParent();
+ }
+}
+
+/// Remove calls to llvm.coro.prepare.retcon, a barrier meant to prevent
+/// IPO from operating on calls to a retcon coroutine before it's been
+/// split. This is only safe to do after we've split all retcon
+/// coroutines in the module. We can do this in this pass because
+/// this pass does promise to split all retcon coroutines (as opposed to
+/// switch coroutines, which are lowered in multiple stages).
+static bool replaceAllPrepares(Function *PrepareFn, CallGraph &CG) {
+ bool Changed = false;
+ for (auto PI = PrepareFn->use_begin(), PE = PrepareFn->use_end();
+ PI != PE; ) {
+ // Intrinsics can only be used in calls.
+ auto *Prepare = cast<CallInst>((PI++)->getUser());
+ replacePrepare(Prepare, CG);
+ Changed = true;
+ }
+
+ return Changed;
+}
+
+//===----------------------------------------------------------------------===//
+// Top Level Driver
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct CoroSplit : public CallGraphSCCPass {
+ static char ID; // Pass identification, replacement for typeid
+
+ CoroSplit() : CallGraphSCCPass(ID) {
+ initializeCoroSplitPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool Run = false;
+
+ // A coroutine is identified by the presence of the coro.begin intrinsic; if
+ // we don't have any, this pass has nothing to do.
+ bool doInitialization(CallGraph &CG) override {
+ Run = coro::declaresIntrinsics(CG.getModule(),
+ {"llvm.coro.begin",
+ "llvm.coro.prepare.retcon"});
+ return CallGraphSCCPass::doInitialization(CG);
+ }
+
+ bool runOnSCC(CallGraphSCC &SCC) override {
+ if (!Run)
+ return false;
+
+ // Check for uses of llvm.coro.prepare.retcon.
+ auto PrepareFn =
+ SCC.getCallGraph().getModule().getFunction("llvm.coro.prepare.retcon");
+ if (PrepareFn && PrepareFn->use_empty())
+ PrepareFn = nullptr;
+
+ // Find coroutines for processing.
+ SmallVector<Function *, 4> Coroutines;
+ for (CallGraphNode *CGN : SCC)
+ if (auto *F = CGN->getFunction())
+ if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
+ Coroutines.push_back(F);
+
+ if (Coroutines.empty() && !PrepareFn)
+ return false;
+
+ CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
+
+ if (Coroutines.empty())
+ return replaceAllPrepares(PrepareFn, CG);
+
+ createDevirtTriggerFunc(CG, SCC);
+
+ // Split all the coroutines.
+ for (Function *F : Coroutines) {
+ Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
+ StringRef Value = Attr.getValueAsString();
+ LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
+ << "' state: " << Value << "\n");
+ if (Value == UNPREPARED_FOR_SPLIT) {
+ prepareForSplit(*F, CG);
+ continue;
+ }
+ F->removeFnAttr(CORO_PRESPLIT_ATTR);
+ splitCoroutine(*F, CG, SCC);
+ }
+
+ if (PrepareFn)
+ replaceAllPrepares(PrepareFn, CG);
+
+ return true;
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ CallGraphSCCPass::getAnalysisUsage(AU);
+ }
+
+ StringRef getPassName() const override { return "Coroutine Splitting"; }
+};
+
+} // end anonymous namespace
+
+char CoroSplit::ID = 0;
+
+INITIALIZE_PASS_BEGIN(
+ CoroSplit, "coro-split",
+ "Split coroutine into a set of functions driving its state machine", false,
+ false)
+INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
+INITIALIZE_PASS_END(
+ CoroSplit, "coro-split",
+ "Split coroutine into a set of functions driving its state machine", false,
+ false)
+
+Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }
diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
new file mode 100644
index 000000000000..f39483b27518
--- /dev/null
+++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp
@@ -0,0 +1,650 @@
+//===- Coroutines.cpp -----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the common infrastructure for Coroutine Passes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Coroutines.h"
+#include "llvm-c/Transforms/Coroutines.h"
+#include "CoroInstr.h"
+#include "CoroInternal.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/Analysis/CallGraphSCCPass.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/IPO/PassManagerBuilder.h"
+#include <cassert>
+#include <cstddef>
+#include <utility>
+
+using namespace llvm;
+
+void llvm::initializeCoroutines(PassRegistry &Registry) {
+ initializeCoroEarlyPass(Registry);
+ initializeCoroSplitPass(Registry);
+ initializeCoroElidePass(Registry);
+ initializeCoroCleanupPass(Registry);
+}
+
+static void addCoroutineOpt0Passes(const PassManagerBuilder &Builder,
+ legacy::PassManagerBase &PM) {
+ PM.add(createCoroSplitPass());
+ PM.add(createCoroElidePass());
+
+ PM.add(createBarrierNoopPass());
+ PM.add(createCoroCleanupPass());
+}
+
+static void addCoroutineEarlyPasses(const PassManagerBuilder &Builder,
+ legacy::PassManagerBase &PM) {
+ PM.add(createCoroEarlyPass());
+}
+
+static void addCoroutineScalarOptimizerPasses(const PassManagerBuilder &Builder,
+ legacy::PassManagerBase &PM) {
+ PM.add(createCoroElidePass());
+}
+
+static void addCoroutineSCCPasses(const PassManagerBuilder &Builder,
+ legacy::PassManagerBase &PM) {
+ PM.add(createCoroSplitPass());
+}
+
+static void addCoroutineOptimizerLastPasses(const PassManagerBuilder &Builder,
+ legacy::PassManagerBase &PM) {
+ PM.add(createCoroCleanupPass());
+}
+
+void llvm::addCoroutinePassesToExtensionPoints(PassManagerBuilder &Builder) {
+ Builder.addExtension(PassManagerBuilder::EP_EarlyAsPossible,
+ addCoroutineEarlyPasses);
+ Builder.addExtension(PassManagerBuilder::EP_EnabledOnOptLevel0,
+ addCoroutineOpt0Passes);
+ Builder.addExtension(PassManagerBuilder::EP_CGSCCOptimizerLate,
+ addCoroutineSCCPasses);
+ Builder.addExtension(PassManagerBuilder::EP_ScalarOptimizerLate,
+ addCoroutineScalarOptimizerPasses);
+ Builder.addExtension(PassManagerBuilder::EP_OptimizerLast,
+ addCoroutineOptimizerLastPasses);
+}
+
+// Construct the lowerer base class and initialize its members.
+coro::LowererBase::LowererBase(Module &M)
+ : TheModule(M), Context(M.getContext()),
+ Int8Ptr(Type::getInt8PtrTy(Context)),
+ ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
+ /*isVarArg=*/false)),
+ NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
+
+// Creates a sequence of instructions to obtain a resume function address using
+// llvm.coro.subfn.addr. It generates the following sequence:
+//
+// %2 = call i8* @llvm.coro.subfn.addr(i8* %Arg, i8 %index)
+// bitcast i8* %2 to void(i8*)*
+
+Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
+ Instruction *InsertPt) {
+ auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
+ auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
+
+ assert(Index >= CoroSubFnInst::IndexFirst &&
+ Index < CoroSubFnInst::IndexLast &&
+ "makeSubFnCall: Index value out of range");
+ auto *Call = CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt);
+
+ auto *Bitcast =
+ new BitCastInst(Call, ResumeFnType->getPointerTo(), "", InsertPt);
+ return Bitcast;
+}
+
+#ifndef NDEBUG
+static bool isCoroutineIntrinsicName(StringRef Name) {
+ // NOTE: Must be sorted!
+ static const char *const CoroIntrinsics[] = {
+ "llvm.coro.alloc",
+ "llvm.coro.begin",
+ "llvm.coro.destroy",
+ "llvm.coro.done",
+ "llvm.coro.end",
+ "llvm.coro.frame",
+ "llvm.coro.free",
+ "llvm.coro.id",
+ "llvm.coro.id.retcon",
+ "llvm.coro.id.retcon.once",
+ "llvm.coro.noop",
+ "llvm.coro.param",
+ "llvm.coro.prepare.retcon",
+ "llvm.coro.promise",
+ "llvm.coro.resume",
+ "llvm.coro.save",
+ "llvm.coro.size",
+ "llvm.coro.subfn.addr",
+ "llvm.coro.suspend",
+ "llvm.coro.suspend.retcon",
+ };
+ return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1;
+}
+#endif
+
+// Verifies if a module declares any of the named values listed. Also, in
+// debug mode, verifies that the names are coroutine intrinsic names.
+bool coro::declaresIntrinsics(Module &M,
+ std::initializer_list<StringRef> List) {
+ for (StringRef Name : List) {
+ assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
+ if (M.getNamedValue(Name))
+ return true;
+ }
+
+ return false;
+}
+
+// Replace all coro.frees associated with the provided CoroId, either with
+// 'null' if Elide is true, or with its frame parameter otherwise.
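+//
+// For example (operands here are hypothetical):
+//   %mem = call i8* @llvm.coro.free(token %id, i8* %frame)
+// is replaced by 'i8* null' when Elide is true, and by '%frame' otherwise.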
+void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
+ SmallVector<CoroFreeInst *, 4> CoroFrees;
+ for (User *U : CoroId->users())
+ if (auto CF = dyn_cast<CoroFreeInst>(U))
+ CoroFrees.push_back(CF);
+
+ if (CoroFrees.empty())
+ return;
+
+ Value *Replacement =
+ Elide ? ConstantPointerNull::get(Type::getInt8PtrTy(CoroId->getContext()))
+ : CoroFrees.front()->getFrame();
+
+ for (CoroFreeInst *CF : CoroFrees) {
+ CF->replaceAllUsesWith(Replacement);
+ CF->eraseFromParent();
+ }
+}
+
+// FIXME: This code is stolen from CallGraph::addToCallGraph(Function *F),
+// which happens to be private. It would be better for this functionality to
+// be exposed by the CallGraph.
+static void buildCGN(CallGraph &CG, CallGraphNode *Node) {
+ Function *F = Node->getFunction();
+
+ // Look for calls by this function.
+ for (Instruction &I : instructions(F))
+ if (auto *Call = dyn_cast<CallBase>(&I)) {
+ const Function *Callee = Call->getCalledFunction();
+ if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID()))
+ // Indirect calls of intrinsics are not allowed so no need to check.
+ // We can be more precise here by using TargetArg returned by
+ // Intrinsic::isLeaf.
+ Node->addCalledFunction(Call, CG.getCallsExternalNode());
+ else if (!Callee->isIntrinsic())
+ Node->addCalledFunction(Call, CG.getOrInsertFunction(Callee));
+ }
+}
+
+// Rebuild CGN after we extracted parts of the code from ParentFunc into
+// NewFuncs. Builds CGNs for the NewFuncs and adds them to the current SCC.
+void coro::updateCallGraph(Function &ParentFunc, ArrayRef<Function *> NewFuncs,
+ CallGraph &CG, CallGraphSCC &SCC) {
+ // Rebuild CGN from scratch for the ParentFunc
+ auto *ParentNode = CG[&ParentFunc];
+ ParentNode->removeAllCalledFunctions();
+ buildCGN(CG, ParentNode);
+
+ SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
+
+ for (Function *F : NewFuncs) {
+ CallGraphNode *Callee = CG.getOrInsertFunction(F);
+ Nodes.push_back(Callee);
+ buildCGN(CG, Callee);
+ }
+
+ SCC.initialize(Nodes);
+}
+
+static void clear(coro::Shape &Shape) {
+ Shape.CoroBegin = nullptr;
+ Shape.CoroEnds.clear();
+ Shape.CoroSizes.clear();
+ Shape.CoroSuspends.clear();
+
+ Shape.FrameTy = nullptr;
+ Shape.FramePtr = nullptr;
+ Shape.AllocaSpillBlock = nullptr;
+}
+
+static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
+ CoroSuspendInst *SuspendInst) {
+ Module *M = SuspendInst->getModule();
+ auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save);
+ auto *SaveInst =
+ cast<CoroSaveInst>(CallInst::Create(Fn, CoroBegin, "", SuspendInst));
+ assert(!SuspendInst->getCoroSave());
+ SuspendInst->setArgOperand(0, SaveInst);
+ return SaveInst;
+}
+
+// Collect "interesting" coroutine intrinsics.
+void coro::Shape::buildFrom(Function &F) {
+ bool HasFinalSuspend = false;
+ size_t FinalSuspendIndex = 0;
+ clear(*this);
+ SmallVector<CoroFrameInst *, 8> CoroFrames;
+ SmallVector<CoroSaveInst *, 2> UnusedCoroSaves;
+
+ for (Instruction &I : instructions(F)) {
+ if (auto II = dyn_cast<IntrinsicInst>(&I)) {
+ switch (II->getIntrinsicID()) {
+ default:
+ continue;
+ case Intrinsic::coro_size:
+ CoroSizes.push_back(cast<CoroSizeInst>(II));
+ break;
+ case Intrinsic::coro_frame:
+ CoroFrames.push_back(cast<CoroFrameInst>(II));
+ break;
+ case Intrinsic::coro_save:
+ // After optimizations, coro_suspends using this coro_save might have
+ // been removed; remember orphaned coro_saves to remove them later.
+ if (II->use_empty())
+ UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
+ break;
+ case Intrinsic::coro_suspend_retcon: {
+ auto Suspend = cast<CoroSuspendRetconInst>(II);
+ CoroSuspends.push_back(Suspend);
+ break;
+ }
+ case Intrinsic::coro_suspend: {
+ auto Suspend = cast<CoroSuspendInst>(II);
+ CoroSuspends.push_back(Suspend);
+ if (Suspend->isFinal()) {
+ if (HasFinalSuspend)
+ report_fatal_error(
+ "Only one suspend point can be marked as final");
+ HasFinalSuspend = true;
+ FinalSuspendIndex = CoroSuspends.size() - 1;
+ }
+ break;
+ }
+ case Intrinsic::coro_begin: {
+ auto CB = cast<CoroBeginInst>(II);
+
+ // Ignore coro ids that aren't pre-split.
+ auto Id = dyn_cast<CoroIdInst>(CB->getId());
+ if (Id && !Id->getInfo().isPreSplit())
+ break;
+
+ if (CoroBegin)
+ report_fatal_error(
+ "coroutine should have exactly one defining @llvm.coro.begin");
+ CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
+ CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
+ CB->removeAttribute(AttributeList::FunctionIndex,
+ Attribute::NoDuplicate);
+ CoroBegin = CB;
+ break;
+ }
+ case Intrinsic::coro_end:
+ CoroEnds.push_back(cast<CoroEndInst>(II));
+ if (CoroEnds.back()->isFallthrough()) {
+ // Make sure that the fallthrough coro.end is the first element in the
+ // CoroEnds vector.
+ if (CoroEnds.size() > 1) {
+ if (CoroEnds.front()->isFallthrough())
+ report_fatal_error(
+ "Only one coro.end can be marked as fallthrough");
+ std::swap(CoroEnds.front(), CoroEnds.back());
+ }
+ }
+ break;
+ }
+ }
+ }
+
+ // If, for some reason, we were not able to find coro.begin, bail out.
+ if (!CoroBegin) {
+ // Replace coro.frame instructions, which are supposed to be lowered to the
+ // result of coro.begin, with undef.
+ auto *Undef = UndefValue::get(Type::getInt8PtrTy(F.getContext()));
+ for (CoroFrameInst *CF : CoroFrames) {
+ CF->replaceAllUsesWith(Undef);
+ CF->eraseFromParent();
+ }
+
+ // Replace all coro.suspend with undef and remove related coro.saves if
+ // present.
+ for (AnyCoroSuspendInst *CS : CoroSuspends) {
+ CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
+ CS->eraseFromParent();
+ if (auto *CoroSave = CS->getCoroSave())
+ CoroSave->eraseFromParent();
+ }
+
+ // Replace all coro.ends with unreachable instruction.
+ for (CoroEndInst *CE : CoroEnds)
+ changeToUnreachable(CE, /*UseLLVMTrap=*/false);
+
+ return;
+ }
+
+ auto Id = CoroBegin->getId();
+ switch (auto IdIntrinsic = Id->getIntrinsicID()) {
+ case Intrinsic::coro_id: {
+ auto SwitchId = cast<CoroIdInst>(Id);
+ this->ABI = coro::ABI::Switch;
+ this->SwitchLowering.HasFinalSuspend = HasFinalSuspend;
+ this->SwitchLowering.ResumeSwitch = nullptr;
+ this->SwitchLowering.PromiseAlloca = SwitchId->getPromise();
+ this->SwitchLowering.ResumeEntryBlock = nullptr;
+
+ for (auto AnySuspend : CoroSuspends) {
+ auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
+ if (!Suspend) {
+#ifndef NDEBUG
+ AnySuspend->dump();
+#endif
+ report_fatal_error("coro.id must be paired with coro.suspend");
+ }
+
+ if (!Suspend->getCoroSave())
+ createCoroSave(CoroBegin, Suspend);
+ }
+ break;
+ }
+
+ case Intrinsic::coro_id_retcon:
+ case Intrinsic::coro_id_retcon_once: {
+ auto ContinuationId = cast<AnyCoroIdRetconInst>(Id);
+ ContinuationId->checkWellFormed();
+ this->ABI = (IdIntrinsic == Intrinsic::coro_id_retcon
+ ? coro::ABI::Retcon
+ : coro::ABI::RetconOnce);
+ auto Prototype = ContinuationId->getPrototype();
+ this->RetconLowering.ResumePrototype = Prototype;
+ this->RetconLowering.Alloc = ContinuationId->getAllocFunction();
+ this->RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
+ this->RetconLowering.ReturnBlock = nullptr;
+ this->RetconLowering.IsFrameInlineInStorage = false;
+
+ // Determine the result value types, and make sure they match up with
+ // the values passed to the suspends.
+ auto ResultTys = getRetconResultTypes();
+ auto ResumeTys = getRetconResumeTypes();
+
+ for (auto AnySuspend : CoroSuspends) {
+ auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
+ if (!Suspend) {
+#ifndef NDEBUG
+ AnySuspend->dump();
+#endif
+ report_fatal_error("coro.id.retcon.* must be paired with "
+ "coro.suspend.retcon");
+ }
+
+ // Check that the argument types of the suspend match the results.
+ auto SI = Suspend->value_begin(), SE = Suspend->value_end();
+ auto RI = ResultTys.begin(), RE = ResultTys.end();
+ for (; SI != SE && RI != RE; ++SI, ++RI) {
+ auto SrcTy = (*SI)->getType();
+ if (SrcTy != *RI) {
+ // The optimizer likes to eliminate bitcasts leading into variadic
+ // calls, but that messes with our invariants. Re-insert the
+ // bitcast and ignore this type mismatch.
+ if (CastInst::isBitCastable(SrcTy, *RI)) {
+ auto BCI = new BitCastInst(*SI, *RI, "", Suspend);
+ SI->set(BCI);
+ continue;
+ }
+
+#ifndef NDEBUG
+ Suspend->dump();
+ Prototype->getFunctionType()->dump();
+#endif
+ report_fatal_error("argument to coro.suspend.retcon does not "
+ "match corresponding prototype function result");
+ }
+ }
+ if (SI != SE || RI != RE) {
+#ifndef NDEBUG
+ Suspend->dump();
+ Prototype->getFunctionType()->dump();
+#endif
+ report_fatal_error("wrong number of arguments to coro.suspend.retcon");
+ }
+
+ // Check that the result type of the suspend matches the resume types.
+ Type *SResultTy = Suspend->getType();
+ ArrayRef<Type*> SuspendResultTys;
+ if (SResultTy->isVoidTy()) {
+ // leave as empty array
+ } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
+ SuspendResultTys = SResultStructTy->elements();
+ } else {
+        // Forms a one-element ArrayRef pointing at the local SResultTy
+        // itself; only valid while SResultTy remains in scope.
+ SuspendResultTys = SResultTy;
+ }
+ if (SuspendResultTys.size() != ResumeTys.size()) {
+#ifndef NDEBUG
+ Suspend->dump();
+ Prototype->getFunctionType()->dump();
+#endif
+ report_fatal_error("wrong number of results from coro.suspend.retcon");
+ }
+ for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
+ if (SuspendResultTys[I] != ResumeTys[I]) {
+#ifndef NDEBUG
+ Suspend->dump();
+ Prototype->getFunctionType()->dump();
+#endif
+ report_fatal_error("result from coro.suspend.retcon does not "
+ "match corresponding prototype function param");
+ }
+ }
+ }
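+
+    // For example (an illustrative pairing; names are placeholders): given a
+    // prototype `declare {i8*, i32} @proto(i8*, i1)`, a conforming suspend
+    // passes one i32 (the prototype results after the continuation pointer)
+    // and yields one i1 (the prototype parameters after the frame pointer):
+    //   %unwind = call i1 (...) @llvm.coro.suspend.retcon.i1(i32 %val)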
+ break;
+ }
+
+ default:
+ llvm_unreachable("coro.begin is not dependent on a coro.id call");
+ }
+
+  // The coro.frame intrinsic is always lowered to the result of coro.begin.
+ for (CoroFrameInst *CF : CoroFrames) {
+ CF->replaceAllUsesWith(CoroBegin);
+ CF->eraseFromParent();
+ }
+
+ // Move final suspend to be the last element in the CoroSuspends vector.
+ if (ABI == coro::ABI::Switch &&
+ SwitchLowering.HasFinalSuspend &&
+ FinalSuspendIndex != CoroSuspends.size() - 1)
+ std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
+
+ // Remove orphaned coro.saves.
+ for (CoroSaveInst *CoroSave : UnusedCoroSaves)
+ CoroSave->eraseFromParent();
+}
+
+static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
+ Call->setCallingConv(Callee->getCallingConv());
+ // TODO: attributes?
+}
+
+static void addCallToCallGraph(CallGraph *CG, CallInst *Call,
+                               Function *Callee) {
+ if (CG)
+ (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]);
+}
+
+Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
+ CallGraph *CG) const {
+ switch (ABI) {
+ case coro::ABI::Switch:
+ llvm_unreachable("can't allocate memory in coro switch-lowering");
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce: {
+ auto Alloc = RetconLowering.Alloc;
+ Size = Builder.CreateIntCast(Size,
+ Alloc->getFunctionType()->getParamType(0),
+                                 /*IsSigned=*/false);
+ auto *Call = Builder.CreateCall(Alloc, Size);
+ propagateCallAttrsFromCallee(Call, Alloc);
+ addCallToCallGraph(CG, Call, Alloc);
+ return Call;
+ }
+ }
+ llvm_unreachable("Unknown coro::ABI enum");
+}
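+
+// A sketch of the IR emitAlloc produces for the retcon ABIs, assuming a
+// hypothetical allocator `declare i8* @my_alloc(i32)` was given to
+// coro.id.retcon.*:
+//   %n = trunc i64 %size to i32        ; CreateIntCast to the param type
+//   %mem = call i8* @my_alloc(i32 %n)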
+
+void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
+ CallGraph *CG) const {
+ switch (ABI) {
+ case coro::ABI::Switch:
+    llvm_unreachable("can't deallocate memory in coro switch-lowering");
+
+ case coro::ABI::Retcon:
+ case coro::ABI::RetconOnce: {
+ auto Dealloc = RetconLowering.Dealloc;
+ Ptr = Builder.CreateBitCast(Ptr,
+ Dealloc->getFunctionType()->getParamType(0));
+ auto *Call = Builder.CreateCall(Dealloc, Ptr);
+ propagateCallAttrsFromCallee(Call, Dealloc);
+ addCallToCallGraph(CG, Call, Dealloc);
+ return;
+ }
+ }
+ llvm_unreachable("Unknown coro::ABI enum");
+}
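+
+// Correspondingly, a sketch of what emitDealloc produces, assuming a
+// hypothetical deallocator `declare void @my_free(i8*)` and a typed frame
+// pointer %ptr:
+//   %p = bitcast %f.frame* %ptr to i8*
+//   call void @my_free(i8* %p)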
+
+LLVM_ATTRIBUTE_NORETURN
+static void fail(const Instruction *I, const char *Reason, Value *V) {
+#ifndef NDEBUG
+ I->dump();
+ if (V) {
+ errs() << " Value: ";
+ V->printAsOperand(llvm::errs());
+ errs() << '\n';
+ }
+#endif
+ report_fatal_error(Reason);
+}
+
+/// Check that the given value is a well-formed prototype for the
+/// llvm.coro.id.retcon.* intrinsics.
+static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
+ auto F = dyn_cast<Function>(V->stripPointerCasts());
+ if (!F)
+ fail(I, "llvm.coro.id.retcon.* prototype not a Function", V);
+
+ auto FT = F->getFunctionType();
+
+ if (isa<CoroIdRetconInst>(I)) {
+ bool ResultOkay;
+ if (FT->getReturnType()->isPointerTy()) {
+ ResultOkay = true;
+ } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) {
+ ResultOkay = (!SRetTy->isOpaque() &&
+ SRetTy->getNumElements() > 0 &&
+ SRetTy->getElementType(0)->isPointerTy());
+ } else {
+ ResultOkay = false;
+ }
+ if (!ResultOkay)
+ fail(I, "llvm.coro.id.retcon prototype must return pointer as first "
+ "result", F);
+
+ if (FT->getReturnType() !=
+ I->getFunction()->getFunctionType()->getReturnType())
+      fail(I, "llvm.coro.id.retcon prototype return type must be same as "
+              "current function return type", F);
+ } else {
+    // No meaningful validation to do here for llvm.coro.id.retcon.once.
+ }
+
+ if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy())
+ fail(I, "llvm.coro.id.retcon.* prototype must take pointer as "
+ "its first parameter", F);
+}
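+
+// For reference, a prototype satisfying the checks above might look like
+// (illustrative; all names are placeholders):
+//   declare i8* @f_prototype(i8* %frame, i1 %unwind)
+// i.e. the first result is a pointer (directly, or as the first element of a
+// returned struct), the first parameter is a pointer, and for coro.id.retcon
+// the return type matches that of the enclosing coroutine.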
+
+/// Check that the given value is a well-formed allocator.
+static void checkWFAlloc(const Instruction *I, Value *V) {
+ auto F = dyn_cast<Function>(V->stripPointerCasts());
+ if (!F)
+ fail(I, "llvm.coro.* allocator not a Function", V);
+
+ auto FT = F->getFunctionType();
+ if (!FT->getReturnType()->isPointerTy())
+ fail(I, "llvm.coro.* allocator must return a pointer", F);
+
+ if (FT->getNumParams() != 1 ||
+ !FT->getParamType(0)->isIntegerTy())
+ fail(I, "llvm.coro.* allocator must take integer as only param", F);
+}
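+
+// e.g. an allocator passing these checks (illustrative placeholder name):
+//   declare i8* @my_alloc(i32 %size)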
+
+/// Check that the given value is a well-formed deallocator.
+static void checkWFDealloc(const Instruction *I, Value *V) {
+ auto F = dyn_cast<Function>(V->stripPointerCasts());
+ if (!F)
+ fail(I, "llvm.coro.* deallocator not a Function", V);
+
+ auto FT = F->getFunctionType();
+ if (!FT->getReturnType()->isVoidTy())
+ fail(I, "llvm.coro.* deallocator must return void", F);
+
+ if (FT->getNumParams() != 1 ||
+ !FT->getParamType(0)->isPointerTy())
+ fail(I, "llvm.coro.* deallocator must take pointer as only param", F);
+}
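+
+// e.g. a deallocator passing these checks (illustrative placeholder name):
+//   declare void @my_free(i8* %ptr)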
+
+static void checkConstantInt(const Instruction *I, Value *V,
+ const char *Reason) {
+  if (!isa<ConstantInt>(V))
+    fail(I, Reason, V);
+}
+
+void AnyCoroIdRetconInst::checkWellFormed() const {
+ checkConstantInt(this, getArgOperand(SizeArg),
+ "size argument to coro.id.retcon.* must be constant");
+ checkConstantInt(this, getArgOperand(AlignArg),
+ "alignment argument to coro.id.retcon.* must be constant");
+ checkWFRetconPrototype(this, getArgOperand(PrototypeArg));
+ checkWFAlloc(this, getArgOperand(AllocArg));
+ checkWFDealloc(this, getArgOperand(DeallocArg));
+}
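+
+// Putting the checks together, a well-formed id might look like this sketch
+// (assuming the documented operand order size, align, storage, prototype,
+// allocator, deallocator; names are placeholders):
+//   %id = call token @llvm.coro.id.retcon(i32 8, i32 8, i8* %buf,
+//            i8* bitcast (i8* (i8*, i1)* @f_prototype to i8*),
+//            i8* bitcast (i8* (i32)* @my_alloc to i8*),
+//            i8* bitcast (void (i8*)* @my_free to i8*))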
+
+void LLVMAddCoroEarlyPass(LLVMPassManagerRef PM) {
+ unwrap(PM)->add(createCoroEarlyPass());
+}
+
+void LLVMAddCoroSplitPass(LLVMPassManagerRef PM) {
+ unwrap(PM)->add(createCoroSplitPass());
+}
+
+void LLVMAddCoroElidePass(LLVMPassManagerRef PM) {
+ unwrap(PM)->add(createCoroElidePass());
+}
+
+void LLVMAddCoroCleanupPass(LLVMPassManagerRef PM) {
+ unwrap(PM)->add(createCoroCleanupPass());
+}
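+
+// A minimal sketch of driving the wrappers above through the standard llvm-c
+// pass manager entry points (Mod is a placeholder LLVMModuleRef):
+//   LLVMPassManagerRef PM = LLVMCreatePassManager();
+//   LLVMAddCoroEarlyPass(PM);
+//   LLVMAddCoroSplitPass(PM);
+//   LLVMAddCoroElidePass(PM);
+//   LLVMAddCoroCleanupPass(PM);
+//   LLVMRunPassManager(PM, Mod);
+//   LLVMDisposePassManager(PM);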