diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2017-04-16 16:02:28 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2017-04-16 16:02:28 +0000 |
commit | 7442d6faa2719e4e7d33a7021c406c5a4facd74d (patch) | |
tree | c72b9241553fc9966179aba84f90f17bfa9235c3 /lib/Sema/SemaCoroutine.cpp | |
parent | b52119637f743680a99710ce5fdb6646da2772af (diff) |
Diffstat (limited to 'lib/Sema/SemaCoroutine.cpp')
-rw-r--r-- | lib/Sema/SemaCoroutine.cpp | 843 |
1 files changed, 581 insertions, 262 deletions
diff --git a/lib/Sema/SemaCoroutine.cpp b/lib/Sema/SemaCoroutine.cpp index 9814b4a84f29b..4a55e51495a81 100644 --- a/lib/Sema/SemaCoroutine.cpp +++ b/lib/Sema/SemaCoroutine.cpp @@ -11,31 +11,46 @@ // //===----------------------------------------------------------------------===// -#include "clang/Sema/SemaInternal.h" +#include "CoroutineStmtBuilder.h" #include "clang/AST/Decl.h" #include "clang/AST/ExprCXX.h" #include "clang/AST/StmtCXX.h" #include "clang/Lex/Preprocessor.h" #include "clang/Sema/Initialization.h" #include "clang/Sema/Overload.h" +#include "clang/Sema/SemaInternal.h" + using namespace clang; using namespace sema; +static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD, + SourceLocation Loc) { + DeclarationName DN = S.PP.getIdentifierInfo(Name); + LookupResult LR(S, DN, Loc, Sema::LookupMemberName); + // Suppress diagnostics when a private member is selected. The same warnings + // will be produced again when building the call. + LR.suppressDiagnostics(); + return S.LookupQualifiedName(LR, RD); +} + /// Look up the std::coroutine_traits<...>::promise_type for the given /// function type. static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType, - SourceLocation Loc) { + SourceLocation KwLoc, + SourceLocation FuncLoc) { // FIXME: Cache std::coroutine_traits once we've found it. NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); if (!StdExp) { - S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); + S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found) + << "std::experimental::coroutine_traits"; return QualType(); } LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_traits"), - Loc, Sema::LookupOrdinaryName); + FuncLoc, Sema::LookupOrdinaryName); if (!S.LookupQualifiedName(Result, StdExp)) { - S.Diag(Loc, diag::err_implied_std_coroutine_traits_not_found); + S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found) + << "std::experimental::coroutine_traits"; return QualType(); } @@ -49,56 +64,107 @@ static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType, } // Form template argument list for coroutine_traits<R, P1, P2, ...>. - TemplateArgumentListInfo Args(Loc, Loc); + TemplateArgumentListInfo Args(KwLoc, KwLoc); Args.addArgument(TemplateArgumentLoc( TemplateArgument(FnType->getReturnType()), - S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), Loc))); + S.Context.getTrivialTypeSourceInfo(FnType->getReturnType(), KwLoc))); // FIXME: If the function is a non-static member function, add the type // of the implicit object parameter before the formal parameters. for (QualType T : FnType->getParamTypes()) Args.addArgument(TemplateArgumentLoc( - TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, Loc))); + TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc))); // Build the template-id. QualType CoroTrait = - S.CheckTemplateIdType(TemplateName(CoroTraits), Loc, Args); + S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args); if (CoroTrait.isNull()) return QualType(); - if (S.RequireCompleteType(Loc, CoroTrait, - diag::err_coroutine_traits_missing_specialization)) + if (S.RequireCompleteType(KwLoc, CoroTrait, + diag::err_coroutine_type_missing_specialization)) return QualType(); - CXXRecordDecl *RD = CoroTrait->getAsCXXRecordDecl(); + auto *RD = CoroTrait->getAsCXXRecordDecl(); assert(RD && "specialization of class template is not a class?"); // Look up the ::promise_type member. - LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), Loc, + LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc, Sema::LookupOrdinaryName); S.LookupQualifiedName(R, RD); auto *Promise = R.getAsSingle<TypeDecl>(); if (!Promise) { - S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_found) + S.Diag(FuncLoc, + diag::err_implied_std_coroutine_traits_promise_type_not_found) << RD; return QualType(); } - // The promise type is required to be a class type. QualType PromiseType = S.Context.getTypeDeclType(Promise); - if (!PromiseType->getAsCXXRecordDecl()) { - // Use the fully-qualified name of the type. + + auto buildElaboratedType = [&]() { auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp); NNS = NestedNameSpecifier::Create(S.Context, NNS, false, CoroTrait.getTypePtr()); - PromiseType = S.Context.getElaboratedType(ETK_None, NNS, PromiseType); + return S.Context.getElaboratedType(ETK_None, NNS, PromiseType); + }; - S.Diag(Loc, diag::err_implied_std_coroutine_traits_promise_type_not_class) - << PromiseType; + if (!PromiseType->getAsCXXRecordDecl()) { + S.Diag(FuncLoc, + diag::err_implied_std_coroutine_traits_promise_type_not_class) + << buildElaboratedType(); return QualType(); } + if (S.RequireCompleteType(FuncLoc, buildElaboratedType(), + diag::err_coroutine_promise_type_incomplete)) + return QualType(); return PromiseType; } +/// Look up the std::coroutine_traits<...>::promise_type for the given +/// function type. +static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType, + SourceLocation Loc) { + if (PromiseType.isNull()) + return QualType(); + + NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace(); + assert(StdExp && "Should already be diagnosed"); + + LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"), + Loc, Sema::LookupOrdinaryName); + if (!S.LookupQualifiedName(Result, StdExp)) { + S.Diag(Loc, diag::err_implied_coroutine_type_not_found) + << "std::experimental::coroutine_handle"; + return QualType(); + } + + ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>(); + if (!CoroHandle) { + Result.suppressDiagnostics(); + // We found something weird. Complain about the first thing we found. + NamedDecl *Found = *Result.begin(); + S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle); + return QualType(); + } + + // Form template argument list for coroutine_handle<Promise>. + TemplateArgumentListInfo Args(Loc, Loc); + Args.addArgument(TemplateArgumentLoc( + TemplateArgument(PromiseType), + S.Context.getTrivialTypeSourceInfo(PromiseType, Loc))); + + // Build the template-id. + QualType CoroHandleType = + S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args); + if (CoroHandleType.isNull()) + return QualType(); + if (S.RequireCompleteType(Loc, CoroHandleType, + diag::err_coroutine_type_missing_specialization)) + return QualType(); + + return CoroHandleType; +} + static bool isValidCoroutineContext(Sema &S, SourceLocation Loc, StringRef Keyword) { // 'co_await' and 'co_yield' are not permitted in unevaluated operands. @@ -160,41 +226,48 @@ static bool isValidCoroutineContext(Sema &S, SourceLocation Loc, return !Diagnosed; } -/// Check that this is a context in which a coroutine suspension can appear. -static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, - StringRef Keyword) { - if (!isValidCoroutineContext(S, Loc, Keyword)) - return nullptr; - - assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope"); - auto *FD = cast<FunctionDecl>(S.CurContext); - auto *ScopeInfo = S.getCurFunction(); - assert(ScopeInfo && "missing function scope for function"); +static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S, + SourceLocation Loc) { + DeclarationName OpName = + SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait); + LookupResult Operators(SemaRef, OpName, SourceLocation(), + Sema::LookupOperatorName); + SemaRef.LookupName(Operators, S); + + assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous"); + const auto &Functions = Operators.asUnresolvedSet(); + bool IsOverloaded = + Functions.size() > 1 || + (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin())); + Expr *CoawaitOp = UnresolvedLookupExpr::Create( + SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(), + DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded, + Functions.begin(), Functions.end()); + assert(CoawaitOp); + return CoawaitOp; +} - // If we don't have a promise variable, build one now. - if (!ScopeInfo->CoroutinePromise) { - QualType T = FD->getType()->isDependentType() - ? S.Context.DependentTy - : lookupPromiseType( - S, FD->getType()->castAs<FunctionProtoType>(), Loc); - if (T.isNull()) - return nullptr; - - // Create and default-initialize the promise. - ScopeInfo->CoroutinePromise = - VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(), - &S.PP.getIdentifierTable().get("__promise"), T, - S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None); - S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise); - if (!ScopeInfo->CoroutinePromise->isInvalidDecl()) - S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise); - } +/// Build a call to 'operator co_await' if there is a suitable operator for +/// the given expression. +static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc, + Expr *E, + UnresolvedLookupExpr *Lookup) { + UnresolvedSet<16> Functions; + Functions.append(Lookup->decls_begin(), Lookup->decls_end()); + return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); +} - return ScopeInfo; +static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, + SourceLocation Loc, Expr *E) { + ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc); + if (R.isInvalid()) + return ExprError(); + return buildOperatorCoawaitCall(SemaRef, Loc, E, + cast<UnresolvedLookupExpr>(R.get())); } static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id, - MutableArrayRef<Expr *> CallArgs) { + MultiExprArg CallArgs) { StringRef Name = S.Context.BuiltinInfo.getName(Id); LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName); S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true); @@ -213,24 +286,41 @@ static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id, return Call.get(); } -/// Build a call to 'operator co_await' if there is a suitable operator for -/// the given expression. -static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S, - SourceLocation Loc, Expr *E) { - UnresolvedSet<16> Functions; - SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(), - Functions); - return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E); +static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType, + SourceLocation Loc) { + QualType CoroHandleType = lookupCoroutineHandleType(S, PromiseType, Loc); + if (CoroHandleType.isNull()) + return ExprError(); + + DeclContext *LookupCtx = S.computeDeclContext(CoroHandleType); + LookupResult Found(S, &S.PP.getIdentifierTable().get("from_address"), Loc, + Sema::LookupOrdinaryName); + if (!S.LookupQualifiedName(Found, LookupCtx)) { + S.Diag(Loc, diag::err_coroutine_handle_missing_member) + << "from_address"; + return ExprError(); + } + + Expr *FramePtr = + buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {}); + + CXXScopeSpec SS; + ExprResult FromAddr = + S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false); + if (FromAddr.isInvalid()) + return ExprError(); + + return S.ActOnCallExpr(nullptr, FromAddr.get(), Loc, FramePtr, Loc); } struct ReadySuspendResumeResult { - bool IsInvalid; Expr *Results[3]; + OpaqueValueExpr *OpaqueValue; + bool IsInvalid; }; static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, - StringRef Name, - MutableArrayRef<Expr *> Args) { + StringRef Name, MultiExprArg Args) { DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc); // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&. @@ -247,18 +337,23 @@ static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc, /// Build calls to await_ready, await_suspend, and await_resume for a co_await /// expression. -static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc, - Expr *E) { +static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise, + SourceLocation Loc, Expr *E) { + OpaqueValueExpr *Operand = new (S.Context) + OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E); + // Assume invalid until we see otherwise. - ReadySuspendResumeResult Calls = {true, {}}; + ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true}; + + ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc); + if (CoroHandleRes.isInvalid()) + return Calls; + Expr *CoroHandle = CoroHandleRes.get(); const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"}; + MultiExprArg Args[] = {None, CoroHandle, None}; for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) { - Expr *Operand = new (S.Context) OpaqueValueExpr( - Loc, E->getType(), VK_LValue, E->getObjectKind(), E); - - // FIXME: Pass coroutine handle to await_suspend. - ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], None); + ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]); if (Result.isInvalid()) return Calls; Calls.Results[I] = Result.get(); @@ -268,26 +363,177 @@ static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, SourceLocation Loc, return Calls; } +static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise, + SourceLocation Loc, StringRef Name, + MultiExprArg Args) { + + // Form a reference to the promise. + ExprResult PromiseRef = S.BuildDeclRefExpr( + Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); + if (PromiseRef.isInvalid()) + return ExprError(); + + // Call 'yield_value', passing in E. + return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); +} + +VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) { + assert(isa<FunctionDecl>(CurContext) && "not in a function scope"); + auto *FD = cast<FunctionDecl>(CurContext); + + QualType T = + FD->getType()->isDependentType() + ? Context.DependentTy + : lookupPromiseType(*this, FD->getType()->castAs<FunctionProtoType>(), + Loc, FD->getLocation()); + if (T.isNull()) + return nullptr; + + auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(), + &PP.getIdentifierTable().get("__promise"), T, + Context.getTrivialTypeSourceInfo(T, Loc), SC_None); + CheckVariableDeclarationType(VD); + if (VD->isInvalidDecl()) + return nullptr; + ActOnUninitializedDecl(VD); + assert(!VD->isInvalidDecl()); + return VD; +} + +/// Check that this is a context in which a coroutine suspension can appear. +static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc, + StringRef Keyword, + bool IsImplicit = false) { + if (!isValidCoroutineContext(S, Loc, Keyword)) + return nullptr; + + assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope"); + + auto *ScopeInfo = S.getCurFunction(); + assert(ScopeInfo && "missing function scope for function"); + + if (ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && !IsImplicit) + ScopeInfo->setFirstCoroutineStmt(Loc, Keyword); + + if (ScopeInfo->CoroutinePromise) + return ScopeInfo; + + ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc); + if (!ScopeInfo->CoroutinePromise) + return nullptr; + + return ScopeInfo; +} + +static bool actOnCoroutineBodyStart(Sema &S, Scope *SC, SourceLocation KWLoc, + StringRef Keyword) { + if (!checkCoroutineContext(S, KWLoc, Keyword)) + return false; + auto *ScopeInfo = S.getCurFunction(); + assert(ScopeInfo->CoroutinePromise); + + // If we have existing coroutine statements then we have already built + // the initial and final suspend points. + if (!ScopeInfo->NeedsCoroutineSuspends) + return true; + + ScopeInfo->setNeedsCoroutineSuspends(false); + + auto *Fn = cast<FunctionDecl>(S.CurContext); + SourceLocation Loc = Fn->getLocation(); + // Build the initial suspend point + auto buildSuspends = [&](StringRef Name) mutable -> StmtResult { + ExprResult Suspend = + buildPromiseCall(S, ScopeInfo->CoroutinePromise, Loc, Name, None); + if (Suspend.isInvalid()) + return StmtError(); + Suspend = buildOperatorCoawaitCall(S, SC, Loc, Suspend.get()); + if (Suspend.isInvalid()) + return StmtError(); + Suspend = S.BuildResolvedCoawaitExpr(Loc, Suspend.get(), + /*IsImplicit*/ true); + Suspend = S.ActOnFinishFullExpr(Suspend.get()); + if (Suspend.isInvalid()) { + S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required) + << ((Name == "initial_suspend") ? 0 : 1); + S.Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword; + return StmtError(); + } + return cast<Stmt>(Suspend.get()); + }; + + StmtResult InitSuspend = buildSuspends("initial_suspend"); + if (InitSuspend.isInvalid()) + return true; + + StmtResult FinalSuspend = buildSuspends("final_suspend"); + if (FinalSuspend.isInvalid()) + return true; + + ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get()); + + return true; +} + ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); - if (!Coroutine) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_await")) { CorrectDelayedTyposInExpr(E); return ExprError(); } + if (E->getType()->isPlaceholderType()) { ExprResult R = CheckPlaceholderExpr(E); if (R.isInvalid()) return ExprError(); E = R.get(); } + ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc); + if (Lookup.isInvalid()) + return ExprError(); + return BuildUnresolvedCoawaitExpr(Loc, E, + cast<UnresolvedLookupExpr>(Lookup.get())); +} + +ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E, + UnresolvedLookupExpr *Lookup) { + auto *FSI = checkCoroutineContext(*this, Loc, "co_await"); + if (!FSI) + return ExprError(); - ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E); + if (E->getType()->isPlaceholderType()) { + ExprResult R = CheckPlaceholderExpr(E); + if (R.isInvalid()) + return ExprError(); + E = R.get(); + } + + auto *Promise = FSI->CoroutinePromise; + if (Promise->getType()->isDependentType()) { + Expr *Res = + new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup); + return Res; + } + + auto *RD = Promise->getType()->getAsCXXRecordDecl(); + if (lookupMember(*this, "await_transform", RD, Loc)) { + ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E); + if (R.isInvalid()) { + Diag(Loc, + diag::note_coroutine_promise_implicit_await_transform_required_here) + << E->getSourceRange(); + return ExprError(); + } + E = R.get(); + } + ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup); if (Awaitable.isInvalid()) return ExprError(); - return BuildCoawaitExpr(Loc, Awaitable.get()); + return BuildResolvedCoawaitExpr(Loc, Awaitable.get()); } -ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await"); + +ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E, + bool IsImplicit) { + auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit); if (!Coroutine) return ExprError(); @@ -298,8 +544,8 @@ ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { } if (E->getType()->isDependentType()) { - Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E); - Coroutine->CoroutineStmts.push_back(Res); + Expr *Res = new (Context) + CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit); return Res; } @@ -309,42 +555,27 @@ ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) { E = CreateMaterializeTemporaryExpr(E->getType(), E, true); // Build the await_ready, await_suspend, await_resume calls. - ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E); + ReadySuspendResumeResult RSS = + buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E); if (RSS.IsInvalid) return ExprError(); - Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], - RSS.Results[2]); - Coroutine->CoroutineStmts.push_back(Res); - return Res; -} - -static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine, - SourceLocation Loc, StringRef Name, - MutableArrayRef<Expr *> Args) { - assert(Coroutine->CoroutinePromise && "no promise for coroutine"); - - // Form a reference to the promise. - auto *Promise = Coroutine->CoroutinePromise; - ExprResult PromiseRef = S.BuildDeclRefExpr( - Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc); - if (PromiseRef.isInvalid()) - return ExprError(); + Expr *Res = + new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1], + RSS.Results[2], RSS.OpaqueValue, IsImplicit); - // Call 'yield_value', passing in E. - return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args); + return Res; } ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield"); - if (!Coroutine) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_yield")) { CorrectDelayedTyposInExpr(E); return ExprError(); } // Build yield_value call. - ExprResult Awaitable = - buildPromiseCall(*this, Coroutine, Loc, "yield_value", E); + ExprResult Awaitable = buildPromiseCall( + *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E); if (Awaitable.isInvalid()) return ExprError(); @@ -368,7 +599,6 @@ ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) { if (E->getType()->isDependentType()) { Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E); - Coroutine->CoroutineStmts.push_back(Res); return Res; } @@ -378,28 +608,29 @@ ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) { E = CreateMaterializeTemporaryExpr(E->getType(), E, true); // Build the await_ready, await_suspend, await_resume calls. - ReadySuspendResumeResult RSS = buildCoawaitCalls(*this, Loc, E); + ReadySuspendResumeResult RSS = + buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E); if (RSS.IsInvalid) return ExprError(); Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1], - RSS.Results[2]); - Coroutine->CoroutineStmts.push_back(Res); + RSS.Results[2], RSS.OpaqueValue); + return Res; } -StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); - if (!Coroutine) { +StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) { + if (!actOnCoroutineBodyStart(*this, S, Loc, "co_return")) { CorrectDelayedTyposInExpr(E); return StmtError(); } return BuildCoreturnStmt(Loc, E); } -StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { - auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return"); - if (!Coroutine) +StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E, + bool IsImplicit) { + auto *FSI = checkCoroutineContext(*this, Loc, "co_return", IsImplicit); + if (!FSI) return StmtError(); if (E && E->getType()->isPlaceholderType() && @@ -412,48 +643,20 @@ StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) { // FIXME: If the operand is a reference to a variable that's about to go out // of scope, we should treat the operand as an xvalue for this overload // resolution. + VarDecl *Promise = FSI->CoroutinePromise; ExprResult PC; if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) { - PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E); + PC = buildPromiseCall(*this, Promise, Loc, "return_value", E); } else { E = MakeFullDiscardedValueExpr(E).get(); - PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None); + PC = buildPromiseCall(*this, Promise, Loc, "return_void", None); } if (PC.isInvalid()) return StmtError(); Expr *PCE = ActOnFinishFullExpr(PC.get()).get(); - Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE); - Coroutine->CoroutineStmts.push_back(Res); - return Res; -} - -static ExprResult buildStdCurrentExceptionCall(Sema &S, SourceLocation Loc) { - NamespaceDecl *Std = S.getStdNamespace(); - if (!Std) { - S.Diag(Loc, diag::err_implied_std_current_exception_not_found); - return ExprError(); - } - LookupResult Result(S, &S.PP.getIdentifierTable().get("current_exception"), - Loc, Sema::LookupOrdinaryName); - if (!S.LookupQualifiedName(Result, Std)) { - S.Diag(Loc, diag::err_implied_std_current_exception_not_found); - return ExprError(); - } - - // FIXME The STL is free to provide more than one overload. - FunctionDecl *FD = Result.getAsSingle<FunctionDecl>(); - if (!FD) { - S.Diag(Loc, diag::err_malformed_std_current_exception); - return ExprError(); - } - ExprResult Res = S.BuildDeclRefExpr(FD, FD->getType(), VK_LValue, Loc); - Res = S.ActOnCallExpr(/*Scope*/ nullptr, Res.get(), Loc, None, Loc); - if (Res.isInvalid()) { - S.Diag(Loc, diag::err_malformed_std_current_exception); - return ExprError(); - } + Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit); return Res; } @@ -482,21 +685,170 @@ static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc, return OperatorDelete; } -// Builds allocation and deallocation for the coroutine. Returns false on -// failure. -static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc, - FunctionScopeInfo *Fn, - Expr *&Allocation, - Stmt *&Deallocation) { - TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo(); - QualType PromiseType = TInfo->getType(); - if (PromiseType->isDependentType()) + +void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { + FunctionScopeInfo *Fn = getCurFunction(); + assert(Fn && Fn->CoroutinePromise && "not a coroutine"); + + if (!Body) { + assert(FD->isInvalidDecl() && + "a null body is only allowed for invalid declarations"); + return; + } + + if (isa<CoroutineBodyStmt>(Body)) { + // FIXME(EricWF): Nothing todo. the body is already a transformed coroutine + // body statement. + return; + } + + // Coroutines [stmt.return]p1: + // A return statement shall not appear in a coroutine. + if (Fn->FirstReturnLoc.isValid()) { + assert(Fn->FirstCoroutineStmtLoc.isValid() && + "first coroutine location not set"); + Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); + Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) + << Fn->getFirstCoroutineStmtKeyword(); + } + CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body); + if (Builder.isInvalid() || !Builder.buildStatements()) + return FD->setInvalidDecl(); + + // Build body for the coroutine wrapper statement. + Body = CoroutineBodyStmt::Create(Context, Builder); +} + +CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD, + sema::FunctionScopeInfo &Fn, + Stmt *Body) + : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()), + IsPromiseDependentType( + !Fn.CoroutinePromise || + Fn.CoroutinePromise->getType()->isDependentType()) { + this->Body = Body; + if (!IsPromiseDependentType) { + PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl(); + assert(PromiseRecordDecl && "Type should have already been checked"); + } + this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend(); +} + +bool CoroutineStmtBuilder::buildStatements() { + assert(this->IsValid && "coroutine already invalid"); + this->IsValid = makeReturnObject() && makeParamMoves(); + if (this->IsValid && !IsPromiseDependentType) + buildDependentStatements(); + return this->IsValid; +} + +bool CoroutineStmtBuilder::buildDependentStatements() { + assert(this->IsValid && "coroutine already invalid"); + assert(!this->IsPromiseDependentType && + "coroutine cannot have a dependent promise type"); + this->IsValid = makeOnException() && makeOnFallthrough() && + makeReturnOnAllocFailure() && makeNewAndDeleteExpr(); + return this->IsValid; +} + +bool CoroutineStmtBuilder::makePromiseStmt() { + // Form a declaration statement for the promise declaration, so that AST + // visitors can more easily find it. + StmtResult PromiseStmt = + S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc); + if (PromiseStmt.isInvalid()) + return false; + + this->Promise = PromiseStmt.get(); + return true; +} + +bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() { + if (Fn.hasInvalidCoroutineSuspends()) + return false; + this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first); + this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second); + return true; +} + +static bool diagReturnOnAllocFailure(Sema &S, Expr *E, + CXXRecordDecl *PromiseRecordDecl, + FunctionScopeInfo &Fn) { + auto Loc = E->getExprLoc(); + if (auto *DeclRef = dyn_cast_or_null<DeclRefExpr>(E)) { + auto *Decl = DeclRef->getDecl(); + if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) { + if (Method->isStatic()) + return true; + else + Loc = Decl->getLocation(); + } + } + + S.Diag( + Loc, + diag::err_coroutine_promise_get_return_object_on_allocation_failure) + << PromiseRecordDecl; + S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here) + << Fn.getFirstCoroutineStmtKeyword(); + return false; +} + +bool CoroutineStmtBuilder::makeReturnOnAllocFailure() { + assert(!IsPromiseDependentType && + "cannot make statement while the promise type is dependent"); + + // [dcl.fct.def.coroutine]/8 + // The unqualified-id get_return_object_on_allocation_failure is looked up in + // the scope of class P by class member access lookup (3.4.5). ... + // If an allocation function returns nullptr, ... the coroutine return value + // is obtained by a call to ... get_return_object_on_allocation_failure(). + + DeclarationName DN = + S.PP.getIdentifierInfo("get_return_object_on_allocation_failure"); + LookupResult Found(S, DN, Loc, Sema::LookupMemberName); + if (!S.LookupQualifiedName(Found, PromiseRecordDecl)) return true; + CXXScopeSpec SS; + ExprResult DeclNameExpr = + S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false); + if (DeclNameExpr.isInvalid()) + return false; + + if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn)) + return false; + + ExprResult ReturnObjectOnAllocationFailure = + S.ActOnCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc); + if (ReturnObjectOnAllocationFailure.isInvalid()) + return false; + + // FIXME: ActOnReturnStmt expects a scope that is inside of the function, due + // to CheckJumpOutOfSEHFinally(*this, ReturnLoc, *CurScope->getFnParent()); + // S.getCurScope()->getFnParent() == nullptr at ActOnFinishFunctionBody when + // CoroutineBodyStmt is built. Figure it out and fix it. + // Use BuildReturnStmt here to unbreak sanitized tests. (Gor:3/27/2017) + StmtResult ReturnStmt = + S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get()); + if (ReturnStmt.isInvalid()) + return false; + + this->ReturnStmtOnAllocFailure = ReturnStmt.get(); + return true; +} + +bool CoroutineStmtBuilder::makeNewAndDeleteExpr() { + // Form and check allocation and deallocation calls. + assert(!IsPromiseDependentType && + "cannot make statement while the promise type is dependent"); + QualType PromiseType = Fn.CoroutinePromise->getType(); + if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type)) return false; - // FIXME: Add support for get_return_object_on_allocation failure. + // FIXME: Add nothrow_t placement arg for global alloc + // if ReturnStmtOnAllocFailure != nullptr. // FIXME: Add support for stateful allocators. FunctionDecl *OperatorNew = nullptr; @@ -532,8 +884,6 @@ static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc, if (NewExpr.isInvalid()) return false; - Allocation = NewExpr.get(); - // Make delete call. QualType OpDeleteQualType = OperatorDelete->getType(); @@ -559,138 +909,107 @@ static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc, if (DeleteExpr.isInvalid()) return false; - Deallocation = DeleteExpr.get(); + this->Allocate = NewExpr.get(); + this->Deallocate = DeleteExpr.get(); return true; } -void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) { - FunctionScopeInfo *Fn = getCurFunction(); - assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine"); +bool CoroutineStmtBuilder::makeOnFallthrough() { + assert(!IsPromiseDependentType && + "cannot make statement while the promise type is dependent"); - // Coroutines [stmt.return]p1: - // A return statement shall not appear in a coroutine. - if (Fn->FirstReturnLoc.isValid()) { - Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine); - auto *First = Fn->CoroutineStmts[0]; - Diag(First->getLocStart(), diag::note_declared_coroutine_here) - << (isa<CoawaitExpr>(First) ? 0 : - isa<CoyieldExpr>(First) ? 1 : 2); - } + // [dcl.fct.def.coroutine]/4 + // The unqualified-ids 'return_void' and 'return_value' are looked up in + // the scope of class P. If both are found, the program is ill-formed. + const bool HasRVoid = lookupMember(S, "return_void", PromiseRecordDecl, Loc); + const bool HasRValue = lookupMember(S, "return_value", PromiseRecordDecl, Loc); - SourceLocation Loc = FD->getLocation(); + StmtResult Fallthrough; + if (HasRVoid && HasRValue) { + // FIXME Improve this diagnostic + S.Diag(FD.getLocation(), diag::err_coroutine_promise_return_ill_formed) + << PromiseRecordDecl; + return false; + } else if (HasRVoid) { + // If the unqualified-id return_void is found, flowing off the end of a + // coroutine is equivalent to a co_return with no operand. Otherwise, + // flowing off the end of a coroutine results in undefined behavior. + Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr, + /*IsImplicit*/false); + Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get()); + if (Fallthrough.isInvalid()) + return false; + } - // Form a declaration statement for the promise declaration, so that AST - // visitors can more easily find it. - StmtResult PromiseStmt = - ActOnDeclStmt(ConvertDeclToDeclGroup(Fn->CoroutinePromise), Loc, Loc); - if (PromiseStmt.isInvalid()) - return FD->setInvalidDecl(); + this->OnFallthrough = Fallthrough.get(); + return true; +} - // Form and check implicit 'co_await p.initial_suspend();' statement. - ExprResult InitialSuspend = - buildPromiseCall(*this, Fn, Loc, "initial_suspend", None); - // FIXME: Support operator co_await here. - if (!InitialSuspend.isInvalid()) - InitialSuspend = BuildCoawaitExpr(Loc, InitialSuspend.get()); - InitialSuspend = ActOnFinishFullExpr(InitialSuspend.get()); - if (InitialSuspend.isInvalid()) - return FD->setInvalidDecl(); +bool CoroutineStmtBuilder::makeOnException() { + // Try to form 'p.unhandled_exception();' + assert(!IsPromiseDependentType && + "cannot make statement while the promise type is dependent"); + + const bool RequireUnhandledException = S.getLangOpts().CXXExceptions; + + if (!lookupMember(S, "unhandled_exception", PromiseRecordDecl, Loc)) { + auto DiagID = + RequireUnhandledException + ? diag::err_coroutine_promise_unhandled_exception_required + : diag:: + warn_coroutine_promise_unhandled_exception_required_with_exceptions; + S.Diag(Loc, DiagID) << PromiseRecordDecl; + return !RequireUnhandledException; + } - // Form and check implicit 'co_await p.final_suspend();' statement. - ExprResult FinalSuspend = - buildPromiseCall(*this, Fn, Loc, "final_suspend", None); - // FIXME: Support operator co_await here. - if (!FinalSuspend.isInvalid()) - FinalSuspend = BuildCoawaitExpr(Loc, FinalSuspend.get()); - FinalSuspend = ActOnFinishFullExpr(FinalSuspend.get()); - if (FinalSuspend.isInvalid()) - return FD->setInvalidDecl(); + // If exceptions are disabled, don't try to build OnException. + if (!S.getLangOpts().CXXExceptions) + return true; - // Form and check allocation and deallocation calls. - Expr *Allocation = nullptr; - Stmt *Deallocation = nullptr; - if (!buildAllocationAndDeallocation(*this, Loc, Fn, Allocation, Deallocation)) - return FD->setInvalidDecl(); + ExprResult UnhandledException = buildPromiseCall(S, Fn.CoroutinePromise, Loc, + "unhandled_exception", None); + UnhandledException = S.ActOnFinishFullExpr(UnhandledException.get(), Loc); + if (UnhandledException.isInvalid()) + return false; - // control flowing off the end of the coroutine. - // Also try to form 'p.set_exception(std::current_exception());' to handle - // uncaught exceptions. - ExprResult SetException; - StmtResult Fallthrough; - if (Fn->CoroutinePromise && - !Fn->CoroutinePromise->getType()->isDependentType()) { - CXXRecordDecl *RD = Fn->CoroutinePromise->getType()->getAsCXXRecordDecl(); - assert(RD && "Type should have already been checked"); - // [dcl.fct.def.coroutine]/4 - // The unqualified-ids 'return_void' and 'return_value' are looked up in - // the scope of class P. If both are found, the program is ill-formed. - DeclarationName RVoidDN = PP.getIdentifierInfo("return_void"); - LookupResult RVoidResult(*this, RVoidDN, Loc, Sema::LookupMemberName); - const bool HasRVoid = LookupQualifiedName(RVoidResult, RD); - - DeclarationName RValueDN = PP.getIdentifierInfo("return_value"); - LookupResult RValueResult(*this, RValueDN, Loc, Sema::LookupMemberName); - const bool HasRValue = LookupQualifiedName(RValueResult, RD); - - if (HasRVoid && HasRValue) { - // FIXME Improve this diagnostic - Diag(FD->getLocation(), diag::err_coroutine_promise_return_ill_formed) - << RD; - return FD->setInvalidDecl(); - } else if (HasRVoid) { - // If the unqualified-id return_void is found, flowing off the end of a - // coroutine is equivalent to a co_return with no operand. Otherwise, - // flowing off the end of a coroutine results in undefined behavior. - Fallthrough = BuildCoreturnStmt(FD->getLocation(), nullptr); - Fallthrough = ActOnFinishFullStmt(Fallthrough.get()); - if (Fallthrough.isInvalid()) - return FD->setInvalidDecl(); - } + this->OnException = UnhandledException.get(); + return true; +} - // [dcl.fct.def.coroutine]/3 - // The unqualified-id set_exception is found in the scope of P by class - // member access lookup (3.4.5). - DeclarationName SetExDN = PP.getIdentifierInfo("set_exception"); - LookupResult SetExResult(*this, SetExDN, Loc, Sema::LookupMemberName); - if (LookupQualifiedName(SetExResult, RD)) { - // Form the call 'p.set_exception(std::current_exception())' - SetException = buildStdCurrentExceptionCall(*this, Loc); - if (SetException.isInvalid()) - return FD->setInvalidDecl(); - Expr *E = SetException.get(); - SetException = buildPromiseCall(*this, Fn, Loc, "set_exception", E); - SetException = ActOnFinishFullExpr(SetException.get(), Loc); - if (SetException.isInvalid()) - return FD->setInvalidDecl(); - } - } +bool CoroutineStmtBuilder::makeReturnObject() { // Build implicit 'p.get_return_object()' expression and form initialization // of return type from it. ExprResult ReturnObject = - buildPromiseCall(*this, Fn, Loc, "get_return_object", None); + buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None); if (ReturnObject.isInvalid()) - return FD->setInvalidDecl(); - QualType RetType = FD->getReturnType(); + return false; + QualType RetType = FD.getReturnType(); if (!RetType->isDependentType()) { InitializedEntity Entity = InitializedEntity::InitializeResult(Loc, RetType, false); - ReturnObject = PerformMoveOrCopyInitialization(Entity, nullptr, RetType, + ReturnObject = S.PerformMoveOrCopyInitialization(Entity, nullptr, RetType, ReturnObject.get()); if (ReturnObject.isInvalid()) - return FD->setInvalidDecl(); + return false; } - ReturnObject = ActOnFinishFullExpr(ReturnObject.get(), Loc); + ReturnObject = S.ActOnFinishFullExpr(ReturnObject.get(), Loc); if (ReturnObject.isInvalid()) - return FD->setInvalidDecl(); + return false; + this->ReturnValue = ReturnObject.get(); + return true; +} + +bool CoroutineStmtBuilder::makeParamMoves() { // FIXME: Perform move-initialization of parameters into frame-local copies. - SmallVector<Expr*, 16> ParamMoves; + return true; +} - // Build body for the coroutine wrapper statement. - Body = new (Context) CoroutineBodyStmt( - Body, PromiseStmt.get(), InitialSuspend.get(), FinalSuspend.get(), - SetException.get(), Fallthrough.get(), Allocation, Deallocation, - ReturnObject.get(), ParamMoves); +StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) { + CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args); + if (!Res) + return StmtError(); + return Res; } |