diff options
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc/RPCUtils.h')
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/RPCUtils.h | 787 |
1 files changed, 538 insertions, 249 deletions
diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 37e2e66e5af4..6212f64ff319 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -26,27 +26,115 @@ #include "llvm/ExecutionEngine/Orc/OrcError.h" #include "llvm/ExecutionEngine/Orc/RPCSerialization.h" -#ifdef _MSC_VER -// concrt.h depends on eh.h for __uncaught_exception declaration -// even if we disable exceptions. -#include <eh.h> - -// Disable warnings from ppltasks.h transitively included by <future>. -#pragma warning(push) -#pragma warning(disable : 4530) -#pragma warning(disable : 4062) -#endif - #include <future> -#ifdef _MSC_VER -#pragma warning(pop) -#endif - namespace llvm { namespace orc { namespace rpc { +/// Base class of all fatal RPC errors (those that necessarily result in the +/// termination of the RPC session). +class RPCFatalError : public ErrorInfo<RPCFatalError> { +public: + static char ID; +}; + +/// RPCConnectionClosed is returned from RPC operations if the RPC connection +/// has already been closed due to either an error or graceful disconnection. +class ConnectionClosed : public ErrorInfo<ConnectionClosed> { +public: + static char ID; + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; +}; + +/// BadFunctionCall is returned from handleOne when the remote makes a call with +/// an unrecognized function id. +/// +/// This error is fatal because Orc RPC needs to know how to parse a function +/// call to know where the next call starts, and if it doesn't recognize the +/// function id it cannot parse the call. +template <typename FnIdT, typename SeqNoT> +class BadFunctionCall + : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> { +public: + static char ID; + + BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) + : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnexpectedRPCCall); + } + + void log(raw_ostream &OS) const override { + OS << "Call to invalid RPC function id '" << FnId << "' with " + "sequence number " << SeqNo; + } + +private: + FnIdT FnId; + SeqNoT SeqNo; +}; + +template <typename FnIdT, typename SeqNoT> +char BadFunctionCall<FnIdT, SeqNoT>::ID = 0; + +/// InvalidSequenceNumberForResponse is returned from handleOne when a response +/// call arrives with a sequence number that doesn't correspond to any in-flight +/// function call. +/// +/// This error is fatal because Orc RPC needs to know how to parse the rest of +/// the response call to know where the next call starts, and if it doesn't have +/// a result parser for this sequence number it can't do that. +template <typename SeqNoT> +class InvalidSequenceNumberForResponse + : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> { +public: + static char ID; + + InvalidSequenceNumberForResponse(SeqNoT SeqNo) + : SeqNo(std::move(SeqNo)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnexpectedRPCCall); + }; + + void log(raw_ostream &OS) const override { + OS << "Response has unknown sequence number " << SeqNo; + } +private: + SeqNoT SeqNo; +}; + +template <typename SeqNoT> +char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0; + +/// This non-fatal error will be passed to asynchronous result handlers in place +/// of a result if the connection goes down before a result returns, or if the +/// function to be called cannot be negotiated with the remote. +class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> { +public: + static char ID; + + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; +}; + +/// This error is returned if the remote does not have a handler installed for +/// the given RPC function. +class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> { +public: + static char ID; + + CouldNotNegotiate(std::string Signature); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const std::string &getSignature() const { return Signature; } +private: + std::string Signature; +}; + template <typename DerivedFunc, typename FnT> class Function; // RPC Function class. @@ -82,16 +170,6 @@ std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex; template <typename DerivedFunc, typename RetT, typename... ArgTs> std::string Function<DerivedFunc, RetT(ArgTs...)>::Name; -/// Provides a typedef for a tuple containing the decayed argument types. -template <typename T> class FunctionArgsTuple; - -template <typename RetT, typename... ArgTs> -class FunctionArgsTuple<RetT(ArgTs...)> { -public: - using Type = std::tuple<typename std::decay< - typename std::remove_reference<ArgTs>::type>::type...>; -}; - /// Allocates RPC function ids during autonegotiation. /// Specializations of this class must provide four members: /// @@ -196,6 +274,16 @@ public: #endif // _MSC_VER +/// Provides a typedef for a tuple containing the decayed argument types. +template <typename T> class FunctionArgsTuple; + +template <typename RetT, typename... ArgTs> +class FunctionArgsTuple<RetT(ArgTs...)> { +public: + using Type = std::tuple<typename std::decay< + typename std::remove_reference<ArgTs>::type>::type...>; +}; + // ResultTraits provides typedefs and utilities specific to the return type // of functions. template <typename RetT> class ResultTraits { @@ -274,43 +362,132 @@ template <> class ResultTraits<Error> : public ResultTraits<void> {}; template <typename RetT> class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; +// Determines whether an RPC function's defined error return type supports +// error return value. +template <typename T> +class SupportsErrorReturn { +public: + static const bool value = false; +}; + +template <> +class SupportsErrorReturn<Error> { +public: + static const bool value = true; +}; + +template <typename T> +class SupportsErrorReturn<Expected<T>> { +public: + static const bool value = true; +}; + +// RespondHelper packages return values based on whether or not the declared +// RPC function return type supports error returns. +template <bool FuncSupportsErrorReturn> +class RespondHelper; + +// RespondHelper specialization for functions that support error returns. +template <> +class RespondHelper<true> { +public: + + // Send Expected<T>. + template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected<HandlerRetT> ResultOrErr) { + if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>()) + return ResultOrErr.takeError(); + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits<ChannelT, WireRetT, + Expected<HandlerRetT>>::serialize( + C, std::move(ResultOrErr))) + return Err; + + // Close the response message. + return C.endSendMessage(); + } + + template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err && Err.isA<RPCFatalError>()) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + if (auto Err2 = serializeSeq(C, std::move(Err))) + return Err2; + return C.endSendMessage(); + } + +}; + +// RespondHelper specialization for functions that do not support error returns. +template <> +class RespondHelper<false> { +public: + + template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected<HandlerRetT> ResultOrErr) { + if (auto Err = ResultOrErr.takeError()) + return Err; + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( + C, *ResultOrErr)) + return Err; + + // Close the response message. + return C.endSendMessage(); + } + + template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + return C.endSendMessage(); + } + +}; + + // Send a response of the given wire return type (WireRetT) over the // channel, with the given sequence number. template <typename WireRetT, typename HandlerRetT, typename ChannelT, typename FunctionIdT, typename SequenceNumberT> -static Error respond(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) { - // If this was an error bail out. - // FIXME: Send an "error" message to the client if this is not a channel - // failure? - if (auto Err = ResultOrErr.takeError()) - return Err; - - // Open the response message. - if (auto Err = C.startSendMessage(ResponseId, SeqNo)) - return Err; - - // Serialize the result. - if (auto Err = - SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( - C, *ResultOrErr)) - return Err; - - // Close the response message. - return C.endSendMessage(); +Error respond(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) { + return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: + template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr)); } // Send an empty response message on the given channel to indicate that // the handler ran. template <typename WireRetT, typename ChannelT, typename FunctionIdT, typename SequenceNumberT> -static Error respond(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, Error Err) { - if (Err) - return Err; - if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) - return Err2; - return C.endSendMessage(); +Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, + Error Err) { + return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: + sendResult(C, ResponseId, SeqNo, std::move(Err)); } // Converts a given type to the equivalent error return type. @@ -339,6 +516,29 @@ public: using Type = Error; }; +// Traits class that strips the response function from the list of handler +// arguments. +template <typename FnT> class AsyncHandlerTraits; + +template <typename ResultT, typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Expected<ResultT>; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template <typename ResponseHandlerT, typename... ArgTs> +class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> : + public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type, + ArgTs...)> {}; + // This template class provides utilities related to RPC function handlers. // The base case applies to non-function types (the template class is // specialized for function types) and inherits from the appropriate @@ -358,15 +558,20 @@ public: // Return type of the handler. using ReturnType = RetT; - // A std::tuple wrapping the handler arguments. - using ArgStorage = typename FunctionArgsTuple<RetT(ArgTs...)>::Type; - // Call the given handler with the given arguments. - template <typename HandlerT> + template <typename HandlerT, typename... TArgTs> static typename WrappedHandlerReturn<RetT>::Type - unpackAndRun(HandlerT &Handler, ArgStorage &Args) { + unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) { return unpackAndRunHelper(Handler, Args, - llvm::index_sequence_for<ArgTs...>()); + llvm::index_sequence_for<TArgTs...>()); + } + + // Call the given handler with the given arguments. + template <typename HandlerT, typename ResponderT, typename... TArgTs> + static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, + std::tuple<TArgTs...> &Args) { + return unpackAndRunAsyncHelper(Handler, Responder, Args, + llvm::index_sequence_for<TArgTs...>()); } // Call the given handler with the given arguments. @@ -379,11 +584,11 @@ public: return Error::success(); } - template <typename HandlerT> + template <typename HandlerT, typename... TArgTs> static typename std::enable_if< !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, typename HandlerTraits<HandlerT>::ReturnType>::type - run(HandlerT &Handler, ArgTs... Args) { + run(HandlerT &Handler, TArgTs... Args) { return Handler(std::move(Args)...); } @@ -408,15 +613,31 @@ private: C, std::get<Indexes>(Args)...); } - template <typename HandlerT, size_t... Indexes> + template <typename HandlerT, typename ArgTuple, size_t... Indexes> static typename WrappedHandlerReturn< typename HandlerTraits<HandlerT>::ReturnType>::Type - unpackAndRunHelper(HandlerT &Handler, ArgStorage &Args, + unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args, llvm::index_sequence<Indexes...>) { return run(Handler, std::move(std::get<Indexes>(Args))...); } + + + template <typename HandlerT, typename ResponderT, typename ArgTuple, + size_t... Indexes> + static typename WrappedHandlerReturn< + typename HandlerTraits<HandlerT>::ReturnType>::Type + unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, + ArgTuple &Args, + llvm::index_sequence<Indexes...>) { + return run(Handler, Responder, std::move(std::get<Indexes>(Args))...); + } }; +// Handler traits for free functions. +template <typename RetT, typename... ArgTs> +class HandlerTraits<RetT(*)(ArgTs...)> + : public HandlerTraits<RetT(ArgTs...)> {}; + // Handler traits for class methods (especially call operators for lambdas). template <typename Class, typename RetT, typename... ArgTs> class HandlerTraits<RetT (Class::*)(ArgTs...)> @@ -471,7 +692,7 @@ public: // Create an error instance representing an abandoned response. static Error createAbandonedResponseError() { - return orcError(OrcErrorCode::RPCResponseAbandoned); + return make_error<ResponseAbandoned>(); } }; @@ -493,7 +714,7 @@ public: return Err; if (auto Err = C.endReceiveMessage()) return Err; - return Handler(Result); + return Handler(std::move(Result)); } // Abandon this response by calling the handler with an 'abandoned response' @@ -538,6 +759,72 @@ private: HandlerT Handler; }; +template <typename ChannelT, typename FuncRetT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using HandlerArgType = typename ResponseHandlerArg< + typename HandlerTraits<HandlerT>::Type>::ArgType; + HandlerArgType Result((typename HandlerArgType::value_type())); + + if (auto Err = + SerializationTraits<ChannelT, Expected<FuncRetT>, + HandlerArgType>::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +template <typename ChannelT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, Error, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + Error Result = Error::success(); + if (auto Err = + SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + // Create a ResponseHandler from a given user handler. template <typename ChannelT, typename FuncRetT, typename HandlerT> std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) { @@ -758,8 +1045,13 @@ public: auto NegotiateId = FnIdAllocator.getNegotiateId(); RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( - [this](const std::string &Name) { return handleNegotiate(Name); }, - LaunchPolicy()); + [this](const std::string &Name) { return handleNegotiate(Name); }); + } + + + /// Negotiate a function id for Func with the other end of the channel. + template <typename Func> Error negotiateFunction(bool Retry = false) { + return getRemoteFunctionId<Func>(true, Retry).takeError(); } /// Append a call Func, does not call send on the channel. @@ -777,14 +1069,12 @@ public: // Look up the function ID. FunctionIdT FnId; - if (auto FnIdOrErr = getRemoteFunctionId<Func>()) + if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false)) FnId = *FnIdOrErr; else { - // This isn't a channel error so we don't want to abandon other pending - // responses, but we still need to run the user handler with an error to - // let them know the call failed. - if (auto Err = Handler(orcError(OrcErrorCode::UnknownRPCFunction))) - report_fatal_error(std::move(Err)); + // Negotiation failed. Notify the handler then return the negotiate-failed + // error. + cantFail(Handler(make_error<ResponseAbandoned>())); return FnIdOrErr.takeError(); } @@ -807,20 +1097,20 @@ public: // Open the function call message. if (auto Err = C.startSendMessage(FnId, SeqNo)) { abandonPendingResponses(); - return joinErrors(std::move(Err), C.endSendMessage()); + return Err; } // Serialize the call arguments. if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs( C, Args...)) { abandonPendingResponses(); - return joinErrors(std::move(Err), C.endSendMessage()); + return Err; } // Close the function call messagee. if (auto Err = C.endSendMessage()) { abandonPendingResponses(); - return std::move(Err); + return Err; } return Error::success(); @@ -839,8 +1129,10 @@ public: Error handleOne() { FunctionIdT FnId; SequenceNumberT SeqNo; - if (auto Err = C.startReceiveMessage(FnId, SeqNo)) + if (auto Err = C.startReceiveMessage(FnId, SeqNo)) { + abandonPendingResponses(); return Err; + } if (FnId == ResponseId) return handleResponse(SeqNo); auto I = Handlers.find(FnId); @@ -848,7 +1140,8 @@ public: return I->second(C, SeqNo); // else: No handler found. Report error to client? - return orcError(OrcErrorCode::UnexpectedRPCCall); + return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId, + SeqNo); } /// Helper for handling setter procedures - this method returns a functor that @@ -887,10 +1180,25 @@ public: SequenceNumberMgr.reset(); } + /// Remove the handler for the given function. + /// A handler must currently be registered for this function. + template <typename Func> + void removeHandler() { + auto IdItr = LocalFunctionIds.find(Func::getPrototype()); + assert(IdItr != LocalFunctionIds.end() && + "Function does not have a registered handler"); + auto HandlerItr = Handlers.find(IdItr->second); + assert(HandlerItr != Handlers.end() && + "Function does not have a registered handler"); + Handlers.erase(HandlerItr); + } + + /// Clear all handlers. + void clearHandlers() { + Handlers.clear(); + } + protected: - // The LaunchPolicy type allows a launch policy to be specified when adding - // a function handler. See addHandlerImpl. - using LaunchPolicy = std::function<Error(std::function<Error()>)>; FunctionIdT getInvalidFunctionId() const { return FnIdAllocator.getInvalidId(); @@ -899,7 +1207,7 @@ protected: /// Add the given handler to the handler map and make it available for /// autonegotiation and execution. template <typename Func, typename HandlerT> - void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) { + void addHandlerImpl(HandlerT Handler) { static_assert(detail::RPCArgTypeCheck< CanDeserializeCheck, typename Func::Type, @@ -908,8 +1216,22 @@ protected: FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); LocalFunctionIds[Func::getPrototype()] = NewFnId; - Handlers[NewFnId] = - wrapHandler<Func>(std::move(Handler), std::move(Launch)); + Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandlerImpl(HandlerT Handler) { + + static_assert(detail::RPCArgTypeCheck< + CanDeserializeCheck, typename Func::Type, + typename detail::AsyncHandlerTraits< + typename detail::HandlerTraits<HandlerT>::Type + >::Type>::value, + ""); + + FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); + LocalFunctionIds[Func::getPrototype()] = NewFnId; + Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler)); } Error handleResponse(SequenceNumberT SeqNo) { @@ -929,7 +1251,8 @@ protected: // Unlock the pending results map to prevent recursive lock. Lock.unlock(); abandonPendingResponses(); - return orcError(OrcErrorCode::UnexpectedRPCResponse); + return make_error< + InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo); } } @@ -951,41 +1274,39 @@ protected: return I->second; } - // Find the remote FunctionId for the given function, which must be in the - // RemoteFunctionIds map. - template <typename Func> Expected<FunctionIdT> getRemoteFunctionId() { - // Try to find the id for the given function. - auto I = RemoteFunctionIds.find(Func::getPrototype()); + // Find the remote FunctionId for the given function. + template <typename Func> + Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap, + bool NegotiateIfInvalid) { + bool DoNegotiate; - // If we have it in the map, return it. - if (I != RemoteFunctionIds.end()) - return I->second; + // Check if we already have a function id... + auto I = RemoteFunctionIds.find(Func::getPrototype()); + if (I != RemoteFunctionIds.end()) { + // If it's valid there's nothing left to do. + if (I->second != getInvalidFunctionId()) + return I->second; + DoNegotiate = NegotiateIfInvalid; + } else + DoNegotiate = NegotiateIfNotInMap; - // Otherwise, if we have auto-negotiation enabled, try to negotiate it. - if (LazyAutoNegotiation) { + // We don't have a function id for Func yet, but we're allowed to try to + // negotiate one. + if (DoNegotiate) { auto &Impl = static_cast<ImplT &>(*this); if (auto RemoteIdOrErr = - Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { - auto &RemoteId = *RemoteIdOrErr; - - // If autonegotiation indicates that the remote end doesn't support this - // function, return an unknown function error. - if (RemoteId == getInvalidFunctionId()) - return orcError(OrcErrorCode::UnknownRPCFunction); - - // Autonegotiation succeeded and returned a valid id. Update the map and - // return the id. - RemoteFunctionIds[Func::getPrototype()] = RemoteId; - return RemoteId; - } else { - // Autonegotiation failed. Return the error. + Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { + RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + if (*RemoteIdOrErr == getInvalidFunctionId()) + return make_error<CouldNotNegotiate>(Func::getPrototype()); + return *RemoteIdOrErr; + } else return RemoteIdOrErr.takeError(); - } } - // No key was available in the map and autonegotiation wasn't enabled. - // Return an unknown function error. - return orcError(OrcErrorCode::UnknownRPCFunction); + // No key was available in the map and we weren't allowed to try to + // negotiate one, so return an unknown function error. + return make_error<CouldNotNegotiate>(Func::getPrototype()); } using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>; @@ -993,12 +1314,15 @@ protected: // Wrap the given user handler in the necessary argument-deserialization code, // result-serialization code, and call to the launch policy (if present). template <typename Func, typename HandlerT> - WrappedHandlerFn wrapHandler(HandlerT Handler, LaunchPolicy Launch) { - return [this, Handler, Launch](ChannelT &Channel, - SequenceNumberT SeqNo) mutable -> Error { + WrappedHandlerFn wrapHandler(HandlerT Handler) { + return [this, Handler](ChannelT &Channel, + SequenceNumberT SeqNo) mutable -> Error { // Start by deserializing the arguments. - auto Args = std::make_shared< - typename detail::HandlerTraits<HandlerT>::ArgStorage>(); + using ArgsTuple = + typename detail::FunctionArgsTuple< + typename detail::HandlerTraits<HandlerT>::Type>::Type; + auto Args = std::make_shared<ArgsTuple>(); + if (auto Err = detail::HandlerTraits<typename Func::Type>::deserializeArgs( Channel, *Args)) @@ -1013,22 +1337,49 @@ protected: if (auto Err = Channel.endReceiveMessage()) return Err; - // Build the handler/responder. - auto Responder = [this, Handler, Args, &Channel, - SeqNo]() mutable -> Error { - using HTraits = detail::HandlerTraits<HandlerT>; - using FuncReturn = typename Func::ReturnType; - return detail::respond<FuncReturn>( - Channel, ResponseId, SeqNo, HTraits::unpackAndRun(Handler, *Args)); - }; - - // If there is an explicit launch policy then use it to launch the - // handler. - if (Launch) - return Launch(std::move(Responder)); - - // Otherwise run the handler on the listener thread. - return Responder(); + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, + HTraits::unpackAndRun(Handler, *Args)); + }; + } + + // Wrap the given user handler in the necessary argument-deserialization code, + // result-serialization code, and call to the launch policy (if present). + template <typename Func, typename HandlerT> + WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { + return [this, Handler](ChannelT &Channel, + SequenceNumberT SeqNo) mutable -> Error { + // Start by deserializing the arguments. + using AHTraits = detail::AsyncHandlerTraits< + typename detail::HandlerTraits<HandlerT>::Type>; + using ArgsTuple = + typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type; + auto Args = std::make_shared<ArgsTuple>(); + + if (auto Err = + detail::HandlerTraits<typename Func::Type>::deserializeArgs( + Channel, *Args)) + return Err; + + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)Args; + + // End receieve message, unlocking the channel for reading. + if (auto Err = Channel.endReceiveMessage()) + return Err; + + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + auto Responder = + [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error { + return detail::respond<FuncReturn>(C, ResponseId, SeqNo, + std::move(RetVal)); + }; + + return HTraits::unpackAndRunAsync(Handler, Responder, *Args); }; } @@ -1068,66 +1419,31 @@ public: MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) : BaseClass(C, LazyAutoNegotiation) {} - /// The LaunchPolicy type allows a launch policy to be specified when adding - /// a function handler. See addHandler. - using LaunchPolicy = typename BaseClass::LaunchPolicy; - /// Add a handler for the given RPC function. /// This installs the given handler functor for the given RPC Function, and /// makes the RPC function available for negotiation/calling from the remote. - /// - /// The optional LaunchPolicy argument can be used to control how the handler - /// is run when called: - /// - /// * If no LaunchPolicy is given, the handler code will be run on the RPC - /// handler thread that is reading from the channel. This handler cannot - /// make blocking RPC calls (since it would be blocking the thread used to - /// get the result), but can make non-blocking calls. - /// - /// * If a LaunchPolicy is given, the user's handler will be wrapped in a - /// call to serialize and send the result, and the resulting functor (with - /// type 'Error()' will be passed to the LaunchPolicy. The user can then - /// choose to add the wrapped handler to a work queue, spawn a new thread, - /// or anything else. template <typename Func, typename HandlerT> - void addHandler(HandlerT Handler, LaunchPolicy Launch = LaunchPolicy()) { - return this->template addHandlerImpl<Func>(std::move(Handler), - std::move(Launch)); + void addHandler(HandlerT Handler) { + return this->template addHandlerImpl<Func>(std::move(Handler)); } /// Add a class-method as a handler. template <typename Func, typename ClassT, typename RetT, typename... ArgTs> - void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...), - LaunchPolicy Launch = LaunchPolicy()) { + void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { addHandler<Func>( - detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method), - Launch); + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); } - /// Negotiate a function id for Func with the other end of the channel. - template <typename Func> Error negotiateFunction(bool Retry = false) { - using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; - - // Check if we already have a function id... - auto I = this->RemoteFunctionIds.find(Func::getPrototype()); - if (I != this->RemoteFunctionIds.end()) { - // If it's valid there's nothing left to do. - if (I->second != this->getInvalidFunctionId()) - return Error::success(); - // If it's invalid and we can't re-attempt negotiation, throw an error. - if (!Retry) - return orcError(OrcErrorCode::UnknownRPCFunction); - } + template <typename Func, typename HandlerT> + void addAsyncHandler(HandlerT Handler) { + return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); + } - // We don't have a function id for Func yet, call the remote to try to - // negotiate one. - if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) { - this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; - if (*RemoteIdOrErr == this->getInvalidFunctionId()) - return orcError(OrcErrorCode::UnknownRPCFunction); - return Error::success(); - } else - return RemoteIdOrErr.takeError(); + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addAsyncHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); } /// Return type for non-blocking call primitives. @@ -1159,7 +1475,6 @@ public: return Error::success(); }, Args...)) { - this->abandonPendingResponses(); RTraits::consumeAbandoned(FutureResult.get()); return std::move(Err); } @@ -1191,15 +1506,9 @@ public: typename AltRetT = typename Func::ReturnType> typename detail::ResultTraits<AltRetT>::ErrorReturnType callB(const ArgTs &... Args) { - if (auto FutureResOrErr = callNB<Func>(Args...)) { - if (auto Err = this->C.send()) { - this->abandonPendingResponses(); - detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( - std::move(FutureResOrErr->get())); - return std::move(Err); - } + if (auto FutureResOrErr = callNB<Func>(Args...)) return FutureResOrErr->get(); - } else + else return FutureResOrErr.takeError(); } @@ -1224,16 +1533,13 @@ private: SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, ChannelT, FunctionIdT, SequenceNumberT>; - using LaunchPolicy = typename BaseClass::LaunchPolicy; - public: SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) : BaseClass(C, LazyAutoNegotiation) {} template <typename Func, typename HandlerT> void addHandler(HandlerT Handler) { - return this->template addHandlerImpl<Func>(std::move(Handler), - LaunchPolicy()); + return this->template addHandlerImpl<Func>(std::move(Handler)); } template <typename Func, typename ClassT, typename RetT, typename... ArgTs> @@ -1242,30 +1548,16 @@ public: detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); } - /// Negotiate a function id for Func with the other end of the channel. - template <typename Func> Error negotiateFunction(bool Retry = false) { - using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; - - // Check if we already have a function id... - auto I = this->RemoteFunctionIds.find(Func::getPrototype()); - if (I != this->RemoteFunctionIds.end()) { - // If it's valid there's nothing left to do. - if (I->second != this->getInvalidFunctionId()) - return Error::success(); - // If it's invalid and we can't re-attempt negotiation, throw an error. - if (!Retry) - return orcError(OrcErrorCode::UnknownRPCFunction); - } + template <typename Func, typename HandlerT> + void addAsyncHandler(HandlerT Handler) { + return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); + } - // We don't have a function id for Func yet, call the remote to try to - // negotiate one. - if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) { - this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; - if (*RemoteIdOrErr == this->getInvalidFunctionId()) - return orcError(OrcErrorCode::UnknownRPCFunction); - return Error::success(); - } else - return RemoteIdOrErr.takeError(); + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addAsyncHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); } template <typename Func, typename... ArgTs, @@ -1287,7 +1579,6 @@ public: return Error::success(); }, Args...)) { - this->abandonPendingResponses(); detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( std::move(Result)); return std::move(Err); @@ -1295,7 +1586,6 @@ public: while (!ReceivedResponse) { if (auto Err = this->handleOne()) { - this->abandonPendingResponses(); detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( std::move(Result)); return std::move(Err); @@ -1306,24 +1596,40 @@ public: } }; +/// Asynchronous dispatch for a function on an RPC endpoint. +template <typename RPCClass, typename Func> +class RPCAsyncDispatch { +public: + RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {} + + template <typename HandlerT, typename... ArgTs> + Error operator()(HandlerT Handler, const ArgTs &... Args) const { + return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...); + } + +private: + RPCClass &Endpoint; +}; + +/// Construct an asynchronous dispatcher from an RPC endpoint and a Func. +template <typename Func, typename RPCEndpointT> +RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) { + return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint); +} + /// \brief Allows a set of asynchrounous calls to be dispatched, and then /// waited on as a group. -template <typename RPCClass> class ParallelCallGroup { +class ParallelCallGroup { public: - /// \brief Construct a parallel call group for the given RPC. - ParallelCallGroup(RPCClass &RPC) : RPC(RPC), NumOutstandingCalls(0) {} - + ParallelCallGroup() = default; ParallelCallGroup(const ParallelCallGroup &) = delete; ParallelCallGroup &operator=(const ParallelCallGroup &) = delete; /// \brief Make as asynchronous call. - /// - /// Does not issue a send call to the RPC's channel. The channel may use this - /// to batch up subsequent calls. A send will automatically be sent when wait - /// is called. - template <typename Func, typename HandlerT, typename... ArgTs> - Error appendCall(HandlerT Handler, const ArgTs &... Args) { + template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs> + Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler, + const ArgTs &... Args) { // Increment the count of outstanding calls. This has to happen before // we invoke the call, as the handler may (depending on scheduling) // be run immediately on another thread, and we don't want the decrement @@ -1346,38 +1652,21 @@ public: return Err; }; - return RPC.template appendCallAsync<Func>(std::move(WrappedHandler), - Args...); - } - - /// \brief Make an asynchronous call. - /// - /// The same as appendCall, but also calls send on the channel immediately. - /// Prefer appendCall if you are about to issue a "wait" call shortly, as - /// this may allow the channel to better batch the calls. - template <typename Func, typename HandlerT, typename... ArgTs> - Error call(HandlerT Handler, const ArgTs &... Args) { - if (auto Err = appendCall(std::move(Handler), Args...)) - return Err; - return RPC.sendAppendedCalls(); + return AsyncDispatch(std::move(WrappedHandler), Args...); } /// \brief Blocks until all calls have been completed and their return value /// handlers run. - Error wait() { - if (auto Err = RPC.sendAppendedCalls()) - return Err; + void wait() { std::unique_lock<std::mutex> Lock(M); while (NumOutstandingCalls > 0) CV.wait(Lock); - return Error::success(); } private: - RPCClass &RPC; std::mutex M; std::condition_variable CV; - uint32_t NumOutstandingCalls; + uint32_t NumOutstandingCalls = 0; }; /// @brief Convenience class for grouping RPC Functions into APIs that can be |