aboutsummaryrefslogtreecommitdiff
path: root/include/llvm/ExecutionEngine/Orc/RPCUtils.h
diff options
context:
space:
mode:
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc/RPCUtils.h')
-rw-r--r--include/llvm/ExecutionEngine/Orc/RPCUtils.h787
1 files changed, 538 insertions, 249 deletions
diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h
index 37e2e66e5af4..6212f64ff319 100644
--- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h
+++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h
@@ -26,27 +26,115 @@
#include "llvm/ExecutionEngine/Orc/OrcError.h"
#include "llvm/ExecutionEngine/Orc/RPCSerialization.h"
-#ifdef _MSC_VER
-// concrt.h depends on eh.h for __uncaught_exception declaration
-// even if we disable exceptions.
-#include <eh.h>
-
-// Disable warnings from ppltasks.h transitively included by <future>.
-#pragma warning(push)
-#pragma warning(disable : 4530)
-#pragma warning(disable : 4062)
-#endif
-
#include <future>
-#ifdef _MSC_VER
-#pragma warning(pop)
-#endif
-
namespace llvm {
namespace orc {
namespace rpc {
+/// Base class of all fatal RPC errors (those that necessarily result in the
+/// termination of the RPC session).
+class RPCFatalError : public ErrorInfo<RPCFatalError> {
+public:
+ static char ID;
+};
+
+/// RPCConnectionClosed is returned from RPC operations if the RPC connection
+/// has already been closed due to either an error or graceful disconnection.
+class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
+public:
+ static char ID;
+ std::error_code convertToErrorCode() const override;
+ void log(raw_ostream &OS) const override;
+};
+
+/// BadFunctionCall is returned from handleOne when the remote makes a call with
+/// an unrecognized function id.
+///
+/// This error is fatal because Orc RPC needs to know how to parse a function
+/// call to know where the next call starts, and if it doesn't recognize the
+/// function id it cannot parse the call.
+template <typename FnIdT, typename SeqNoT>
+class BadFunctionCall
+ : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
+public:
+ static char ID;
+
+ BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
+ : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
+
+ std::error_code convertToErrorCode() const override {
+ return orcError(OrcErrorCode::UnexpectedRPCCall);
+ }
+
+ void log(raw_ostream &OS) const override {
+ OS << "Call to invalid RPC function id '" << FnId << "' with "
+ "sequence number " << SeqNo;
+ }
+
+private:
+ FnIdT FnId;
+ SeqNoT SeqNo;
+};
+
+template <typename FnIdT, typename SeqNoT>
+char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
+
+/// InvalidSequenceNumberForResponse is returned from handleOne when a response
+/// call arrives with a sequence number that doesn't correspond to any in-flight
+/// function call.
+///
+/// This error is fatal because Orc RPC needs to know how to parse the rest of
+/// the response call to know where the next call starts, and if it doesn't have
+/// a result parser for this sequence number it can't do that.
+template <typename SeqNoT>
+class InvalidSequenceNumberForResponse
+ : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> {
+public:
+ static char ID;
+
+ InvalidSequenceNumberForResponse(SeqNoT SeqNo)
+ : SeqNo(std::move(SeqNo)) {}
+
+ std::error_code convertToErrorCode() const override {
+ return orcError(OrcErrorCode::UnexpectedRPCCall);
+ };
+
+ void log(raw_ostream &OS) const override {
+ OS << "Response has unknown sequence number " << SeqNo;
+ }
+private:
+ SeqNoT SeqNo;
+};
+
+template <typename SeqNoT>
+char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
+
+/// This non-fatal error will be passed to asynchronous result handlers in place
+/// of a result if the connection goes down before a result returns, or if the
+/// function to be called cannot be negotiated with the remote.
+class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
+public:
+ static char ID;
+
+ std::error_code convertToErrorCode() const override;
+ void log(raw_ostream &OS) const override;
+};
+
+/// This error is returned if the remote does not have a handler installed for
+/// the given RPC function.
+class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
+public:
+ static char ID;
+
+ CouldNotNegotiate(std::string Signature);
+ std::error_code convertToErrorCode() const override;
+ void log(raw_ostream &OS) const override;
+ const std::string &getSignature() const { return Signature; }
+private:
+ std::string Signature;
+};
+
template <typename DerivedFunc, typename FnT> class Function;
// RPC Function class.
@@ -82,16 +170,6 @@ std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex;
template <typename DerivedFunc, typename RetT, typename... ArgTs>
std::string Function<DerivedFunc, RetT(ArgTs...)>::Name;
-/// Provides a typedef for a tuple containing the decayed argument types.
-template <typename T> class FunctionArgsTuple;
-
-template <typename RetT, typename... ArgTs>
-class FunctionArgsTuple<RetT(ArgTs...)> {
-public:
- using Type = std::tuple<typename std::decay<
- typename std::remove_reference<ArgTs>::type>::type...>;
-};
-
/// Allocates RPC function ids during autonegotiation.
/// Specializations of this class must provide four members:
///
@@ -196,6 +274,16 @@ public:
#endif // _MSC_VER
+/// Provides a typedef for a tuple containing the decayed argument types.
+template <typename T> class FunctionArgsTuple;
+
+template <typename RetT, typename... ArgTs>
+class FunctionArgsTuple<RetT(ArgTs...)> {
+public:
+ using Type = std::tuple<typename std::decay<
+ typename std::remove_reference<ArgTs>::type>::type...>;
+};
+
// ResultTraits provides typedefs and utilities specific to the return type
// of functions.
template <typename RetT> class ResultTraits {
@@ -274,43 +362,132 @@ template <> class ResultTraits<Error> : public ResultTraits<void> {};
template <typename RetT>
class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
+// Determines whether an RPC function's defined error return type supports
+// error return value.
+template <typename T>
+class SupportsErrorReturn {
+public:
+ static const bool value = false;
+};
+
+template <>
+class SupportsErrorReturn<Error> {
+public:
+ static const bool value = true;
+};
+
+template <typename T>
+class SupportsErrorReturn<Expected<T>> {
+public:
+ static const bool value = true;
+};
+
+// RespondHelper packages return values based on whether or not the declared
+// RPC function return type supports error returns.
+template <bool FuncSupportsErrorReturn>
+class RespondHelper;
+
+// RespondHelper specialization for functions that support error returns.
+template <>
+class RespondHelper<true> {
+public:
+
+ // Send Expected<T>.
+ template <typename WireRetT, typename HandlerRetT, typename ChannelT,
+ typename FunctionIdT, typename SequenceNumberT>
+ static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
+ SequenceNumberT SeqNo,
+ Expected<HandlerRetT> ResultOrErr) {
+ if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
+ return ResultOrErr.takeError();
+
+ // Open the response message.
+ if (auto Err = C.startSendMessage(ResponseId, SeqNo))
+ return Err;
+
+ // Serialize the result.
+ if (auto Err =
+ SerializationTraits<ChannelT, WireRetT,
+ Expected<HandlerRetT>>::serialize(
+ C, std::move(ResultOrErr)))
+ return Err;
+
+ // Close the response message.
+ return C.endSendMessage();
+ }
+
+ template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
+ static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
+ SequenceNumberT SeqNo, Error Err) {
+ if (Err && Err.isA<RPCFatalError>())
+ return Err;
+ if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
+ return Err2;
+ if (auto Err2 = serializeSeq(C, std::move(Err)))
+ return Err2;
+ return C.endSendMessage();
+ }
+
+};
+
+// RespondHelper specialization for functions that do not support error returns.
+template <>
+class RespondHelper<false> {
+public:
+
+ template <typename WireRetT, typename HandlerRetT, typename ChannelT,
+ typename FunctionIdT, typename SequenceNumberT>
+ static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
+ SequenceNumberT SeqNo,
+ Expected<HandlerRetT> ResultOrErr) {
+ if (auto Err = ResultOrErr.takeError())
+ return Err;
+
+ // Open the response message.
+ if (auto Err = C.startSendMessage(ResponseId, SeqNo))
+ return Err;
+
+ // Serialize the result.
+ if (auto Err =
+ SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
+ C, *ResultOrErr))
+ return Err;
+
+ // Close the response message.
+ return C.endSendMessage();
+ }
+
+ template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
+ static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
+ SequenceNumberT SeqNo, Error Err) {
+ if (Err)
+ return Err;
+ if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
+ return Err2;
+ return C.endSendMessage();
+ }
+
+};
+
+
// Send a response of the given wire return type (WireRetT) over the
// channel, with the given sequence number.
template <typename WireRetT, typename HandlerRetT, typename ChannelT,
typename FunctionIdT, typename SequenceNumberT>
-static Error respond(ChannelT &C, const FunctionIdT &ResponseId,
- SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
- // If this was an error bail out.
- // FIXME: Send an "error" message to the client if this is not a channel
- // failure?
- if (auto Err = ResultOrErr.takeError())
- return Err;
-
- // Open the response message.
- if (auto Err = C.startSendMessage(ResponseId, SeqNo))
- return Err;
-
- // Serialize the result.
- if (auto Err =
- SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
- C, *ResultOrErr))
- return Err;
-
- // Close the response message.
- return C.endSendMessage();
+Error respond(ChannelT &C, const FunctionIdT &ResponseId,
+ SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
+ return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
+ template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
}
// Send an empty response message on the given channel to indicate that
// the handler ran.
template <typename WireRetT, typename ChannelT, typename FunctionIdT,
typename SequenceNumberT>
-static Error respond(ChannelT &C, const FunctionIdT &ResponseId,
- SequenceNumberT SeqNo, Error Err) {
- if (Err)
- return Err;
- if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
- return Err2;
- return C.endSendMessage();
+Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
+ Error Err) {
+ return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
+ sendResult(C, ResponseId, SeqNo, std::move(Err));
}
// Converts a given type to the equivalent error return type.
@@ -339,6 +516,29 @@ public:
using Type = Error;
};
+// Traits class that strips the response function from the list of handler
+// arguments.
+template <typename FnT> class AsyncHandlerTraits;
+
+template <typename ResultT, typename... ArgTs>
+class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
+public:
+ using Type = Error(ArgTs...);
+ using ResultType = Expected<ResultT>;
+};
+
+template <typename... ArgTs>
+class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
+public:
+ using Type = Error(ArgTs...);
+ using ResultType = Error;
+};
+
+template <typename ResponseHandlerT, typename... ArgTs>
+class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
+ public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
+ ArgTs...)> {};
+
// This template class provides utilities related to RPC function handlers.
// The base case applies to non-function types (the template class is
// specialized for function types) and inherits from the appropriate
@@ -358,15 +558,20 @@ public:
// Return type of the handler.
using ReturnType = RetT;
- // A std::tuple wrapping the handler arguments.
- using ArgStorage = typename FunctionArgsTuple<RetT(ArgTs...)>::Type;
-
// Call the given handler with the given arguments.
- template <typename HandlerT>
+ template <typename HandlerT, typename... TArgTs>
static typename WrappedHandlerReturn<RetT>::Type
- unpackAndRun(HandlerT &Handler, ArgStorage &Args) {
+ unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
return unpackAndRunHelper(Handler, Args,
- llvm::index_sequence_for<ArgTs...>());
+ llvm::index_sequence_for<TArgTs...>());
+ }
+
+ // Call the given handler with the given arguments.
+ template <typename HandlerT, typename ResponderT, typename... TArgTs>
+ static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
+ std::tuple<TArgTs...> &Args) {
+ return unpackAndRunAsyncHelper(Handler, Responder, Args,
+ llvm::index_sequence_for<TArgTs...>());
}
// Call the given handler with the given arguments.
@@ -379,11 +584,11 @@ public:
return Error::success();
}
- template <typename HandlerT>
+ template <typename HandlerT, typename... TArgTs>
static typename std::enable_if<
!std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
typename HandlerTraits<HandlerT>::ReturnType>::type
- run(HandlerT &Handler, ArgTs... Args) {
+ run(HandlerT &Handler, TArgTs... Args) {
return Handler(std::move(Args)...);
}
@@ -408,15 +613,31 @@ private:
C, std::get<Indexes>(Args)...);
}
- template <typename HandlerT, size_t... Indexes>
+ template <typename HandlerT, typename ArgTuple, size_t... Indexes>
static typename WrappedHandlerReturn<
typename HandlerTraits<HandlerT>::ReturnType>::Type
- unpackAndRunHelper(HandlerT &Handler, ArgStorage &Args,
+ unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
llvm::index_sequence<Indexes...>) {
return run(Handler, std::move(std::get<Indexes>(Args))...);
}
+
+
+ template <typename HandlerT, typename ResponderT, typename ArgTuple,
+ size_t... Indexes>
+ static typename WrappedHandlerReturn<
+ typename HandlerTraits<HandlerT>::ReturnType>::Type
+ unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
+ ArgTuple &Args,
+ llvm::index_sequence<Indexes...>) {
+ return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
+ }
};
+// Handler traits for free functions.
+template <typename RetT, typename... ArgTs>
+class HandlerTraits<RetT(*)(ArgTs...)>
+ : public HandlerTraits<RetT(ArgTs...)> {};
+
// Handler traits for class methods (especially call operators for lambdas).
template <typename Class, typename RetT, typename... ArgTs>
class HandlerTraits<RetT (Class::*)(ArgTs...)>
@@ -471,7 +692,7 @@ public:
// Create an error instance representing an abandoned response.
static Error createAbandonedResponseError() {
- return orcError(OrcErrorCode::RPCResponseAbandoned);
+ return make_error<ResponseAbandoned>();
}
};
@@ -493,7 +714,7 @@ public:
return Err;
if (auto Err = C.endReceiveMessage())
return Err;
- return Handler(Result);
+ return Handler(std::move(Result));
}
// Abandon this response by calling the handler with an 'abandoned response'
@@ -538,6 +759,72 @@ private:
HandlerT Handler;
};
+template <typename ChannelT, typename FuncRetT, typename HandlerT>
+class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
+ : public ResponseHandler<ChannelT> {
+public:
+ ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
+
+ // Handle the result by deserializing it from the channel then passing it
+ // to the user defined handler.
+ Error handleResponse(ChannelT &C) override {
+ using HandlerArgType = typename ResponseHandlerArg<
+ typename HandlerTraits<HandlerT>::Type>::ArgType;
+ HandlerArgType Result((typename HandlerArgType::value_type()));
+
+ if (auto Err =
+ SerializationTraits<ChannelT, Expected<FuncRetT>,
+ HandlerArgType>::deserialize(C, Result))
+ return Err;
+ if (auto Err = C.endReceiveMessage())
+ return Err;
+ return Handler(std::move(Result));
+ }
+
+ // Abandon this response by calling the handler with an 'abandoned response'
+ // error.
+ void abandon() override {
+ if (auto Err = Handler(this->createAbandonedResponseError())) {
+ // Handlers should not fail when passed an abandoned response error.
+ report_fatal_error(std::move(Err));
+ }
+ }
+
+private:
+ HandlerT Handler;
+};
+
+template <typename ChannelT, typename HandlerT>
+class ResponseHandlerImpl<ChannelT, Error, HandlerT>
+ : public ResponseHandler<ChannelT> {
+public:
+ ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
+
+ // Handle the result by deserializing it from the channel then passing it
+ // to the user defined handler.
+ Error handleResponse(ChannelT &C) override {
+ Error Result = Error::success();
+ if (auto Err =
+ SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result))
+ return Err;
+ if (auto Err = C.endReceiveMessage())
+ return Err;
+ return Handler(std::move(Result));
+ }
+
+ // Abandon this response by calling the handler with an 'abandoned response'
+ // error.
+ void abandon() override {
+ if (auto Err = Handler(this->createAbandonedResponseError())) {
+ // Handlers should not fail when passed an abandoned response error.
+ report_fatal_error(std::move(Err));
+ }
+ }
+
+private:
+ HandlerT Handler;
+};
+
// Create a ResponseHandler from a given user handler.
template <typename ChannelT, typename FuncRetT, typename HandlerT>
std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
@@ -758,8 +1045,13 @@ public:
auto NegotiateId = FnIdAllocator.getNegotiateId();
RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
- [this](const std::string &Name) { return handleNegotiate(Name); },
- LaunchPolicy());
+ [this](const std::string &Name) { return handleNegotiate(Name); });
+ }
+
+
+ /// Negotiate a function id for Func with the other end of the channel.
+ template <typename Func> Error negotiateFunction(bool Retry = false) {
+ return getRemoteFunctionId<Func>(true, Retry).takeError();
}
/// Append a call Func, does not call send on the channel.
@@ -777,14 +1069,12 @@ public:
// Look up the function ID.
FunctionIdT FnId;
- if (auto FnIdOrErr = getRemoteFunctionId<Func>())
+ if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
FnId = *FnIdOrErr;
else {
- // This isn't a channel error so we don't want to abandon other pending
- // responses, but we still need to run the user handler with an error to
- // let them know the call failed.
- if (auto Err = Handler(orcError(OrcErrorCode::UnknownRPCFunction)))
- report_fatal_error(std::move(Err));
+ // Negotiation failed. Notify the handler then return the negotiate-failed
+ // error.
+ cantFail(Handler(make_error<ResponseAbandoned>()));
return FnIdOrErr.takeError();
}
@@ -807,20 +1097,20 @@ public:
// Open the function call message.
if (auto Err = C.startSendMessage(FnId, SeqNo)) {
abandonPendingResponses();
- return joinErrors(std::move(Err), C.endSendMessage());
+ return Err;
}
// Serialize the call arguments.
if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
C, Args...)) {
abandonPendingResponses();
- return joinErrors(std::move(Err), C.endSendMessage());
+ return Err;
}
// Close the function call messagee.
if (auto Err = C.endSendMessage()) {
abandonPendingResponses();
- return std::move(Err);
+ return Err;
}
return Error::success();
@@ -839,8 +1129,10 @@ public:
Error handleOne() {
FunctionIdT FnId;
SequenceNumberT SeqNo;
- if (auto Err = C.startReceiveMessage(FnId, SeqNo))
+ if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
+ abandonPendingResponses();
return Err;
+ }
if (FnId == ResponseId)
return handleResponse(SeqNo);
auto I = Handlers.find(FnId);
@@ -848,7 +1140,8 @@ public:
return I->second(C, SeqNo);
// else: No handler found. Report error to client?
- return orcError(OrcErrorCode::UnexpectedRPCCall);
+ return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
+ SeqNo);
}
/// Helper for handling setter procedures - this method returns a functor that
@@ -887,10 +1180,25 @@ public:
SequenceNumberMgr.reset();
}
+ /// Remove the handler for the given function.
+ /// A handler must currently be registered for this function.
+ template <typename Func>
+ void removeHandler() {
+ auto IdItr = LocalFunctionIds.find(Func::getPrototype());
+ assert(IdItr != LocalFunctionIds.end() &&
+ "Function does not have a registered handler");
+ auto HandlerItr = Handlers.find(IdItr->second);
+ assert(HandlerItr != Handlers.end() &&
+ "Function does not have a registered handler");
+ Handlers.erase(HandlerItr);
+ }
+
+ /// Clear all handlers.
+ void clearHandlers() {
+ Handlers.clear();
+ }
+
protected:
- // The LaunchPolicy type allows a launch policy to be specified when adding
- // a function handler. See addHandlerImpl.
- using LaunchPolicy = std::function<Error(std::function<Error()>)>;
FunctionIdT getInvalidFunctionId() const {
return FnIdAllocator.getInvalidId();
@@ -899,7 +1207,7 @@ protected:
/// Add the given handler to the handler map and make it available for
/// autonegotiation and execution.
template <typename Func, typename HandlerT>
- void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) {
+ void addHandlerImpl(HandlerT Handler) {
static_assert(detail::RPCArgTypeCheck<
CanDeserializeCheck, typename Func::Type,
@@ -908,8 +1216,22 @@ protected:
FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
LocalFunctionIds[Func::getPrototype()] = NewFnId;
- Handlers[NewFnId] =
- wrapHandler<Func>(std::move(Handler), std::move(Launch));
+ Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
+ }
+
+ template <typename Func, typename HandlerT>
+ void addAsyncHandlerImpl(HandlerT Handler) {
+
+ static_assert(detail::RPCArgTypeCheck<
+ CanDeserializeCheck, typename Func::Type,
+ typename detail::AsyncHandlerTraits<
+ typename detail::HandlerTraits<HandlerT>::Type
+ >::Type>::value,
+ "");
+
+ FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
+ LocalFunctionIds[Func::getPrototype()] = NewFnId;
+ Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
}
Error handleResponse(SequenceNumberT SeqNo) {
@@ -929,7 +1251,8 @@ protected:
// Unlock the pending results map to prevent recursive lock.
Lock.unlock();
abandonPendingResponses();
- return orcError(OrcErrorCode::UnexpectedRPCResponse);
+ return make_error<
+ InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);
}
}
@@ -951,41 +1274,39 @@ protected:
return I->second;
}
- // Find the remote FunctionId for the given function, which must be in the
- // RemoteFunctionIds map.
- template <typename Func> Expected<FunctionIdT> getRemoteFunctionId() {
- // Try to find the id for the given function.
- auto I = RemoteFunctionIds.find(Func::getPrototype());
+ // Find the remote FunctionId for the given function.
+ template <typename Func>
+ Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
+ bool NegotiateIfInvalid) {
+ bool DoNegotiate;
- // If we have it in the map, return it.
- if (I != RemoteFunctionIds.end())
- return I->second;
+ // Check if we already have a function id...
+ auto I = RemoteFunctionIds.find(Func::getPrototype());
+ if (I != RemoteFunctionIds.end()) {
+ // If it's valid there's nothing left to do.
+ if (I->second != getInvalidFunctionId())
+ return I->second;
+ DoNegotiate = NegotiateIfInvalid;
+ } else
+ DoNegotiate = NegotiateIfNotInMap;
- // Otherwise, if we have auto-negotiation enabled, try to negotiate it.
- if (LazyAutoNegotiation) {
+ // We don't have a function id for Func yet, but we're allowed to try to
+ // negotiate one.
+ if (DoNegotiate) {
auto &Impl = static_cast<ImplT &>(*this);
if (auto RemoteIdOrErr =
- Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
- auto &RemoteId = *RemoteIdOrErr;
-
- // If autonegotiation indicates that the remote end doesn't support this
- // function, return an unknown function error.
- if (RemoteId == getInvalidFunctionId())
- return orcError(OrcErrorCode::UnknownRPCFunction);
-
- // Autonegotiation succeeded and returned a valid id. Update the map and
- // return the id.
- RemoteFunctionIds[Func::getPrototype()] = RemoteId;
- return RemoteId;
- } else {
- // Autonegotiation failed. Return the error.
+ Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
+ RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
+ if (*RemoteIdOrErr == getInvalidFunctionId())
+ return make_error<CouldNotNegotiate>(Func::getPrototype());
+ return *RemoteIdOrErr;
+ } else
return RemoteIdOrErr.takeError();
- }
}
- // No key was available in the map and autonegotiation wasn't enabled.
- // Return an unknown function error.
- return orcError(OrcErrorCode::UnknownRPCFunction);
+ // No key was available in the map and we weren't allowed to try to
+ // negotiate one, so return an unknown function error.
+ return make_error<CouldNotNegotiate>(Func::getPrototype());
}
using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;
@@ -993,12 +1314,15 @@ protected:
// Wrap the given user handler in the necessary argument-deserialization code,
// result-serialization code, and call to the launch policy (if present).
template <typename Func, typename HandlerT>
- WrappedHandlerFn wrapHandler(HandlerT Handler, LaunchPolicy Launch) {
- return [this, Handler, Launch](ChannelT &Channel,
- SequenceNumberT SeqNo) mutable -> Error {
+ WrappedHandlerFn wrapHandler(HandlerT Handler) {
+ return [this, Handler](ChannelT &Channel,
+ SequenceNumberT SeqNo) mutable -> Error {
// Start by deserializing the arguments.
- auto Args = std::make_shared<
- typename detail::HandlerTraits<HandlerT>::ArgStorage>();
+ using ArgsTuple =
+ typename detail::FunctionArgsTuple<
+ typename detail::HandlerTraits<HandlerT>::Type>::Type;
+ auto Args = std::make_shared<ArgsTuple>();
+
if (auto Err =
detail::HandlerTraits<typename Func::Type>::deserializeArgs(
Channel, *Args))
@@ -1013,22 +1337,49 @@ protected:
if (auto Err = Channel.endReceiveMessage())
return Err;
- // Build the handler/responder.
- auto Responder = [this, Handler, Args, &Channel,
- SeqNo]() mutable -> Error {
- using HTraits = detail::HandlerTraits<HandlerT>;
- using FuncReturn = typename Func::ReturnType;
- return detail::respond<FuncReturn>(
- Channel, ResponseId, SeqNo, HTraits::unpackAndRun(Handler, *Args));
- };
-
- // If there is an explicit launch policy then use it to launch the
- // handler.
- if (Launch)
- return Launch(std::move(Responder));
-
- // Otherwise run the handler on the listener thread.
- return Responder();
+ using HTraits = detail::HandlerTraits<HandlerT>;
+ using FuncReturn = typename Func::ReturnType;
+ return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
+ HTraits::unpackAndRun(Handler, *Args));
+ };
+ }
+
+ // Wrap the given user handler in the necessary argument-deserialization code,
+ // result-serialization code, and call to the launch policy (if present).
+ template <typename Func, typename HandlerT>
+ WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
+ return [this, Handler](ChannelT &Channel,
+ SequenceNumberT SeqNo) mutable -> Error {
+ // Start by deserializing the arguments.
+ using AHTraits = detail::AsyncHandlerTraits<
+ typename detail::HandlerTraits<HandlerT>::Type>;
+ using ArgsTuple =
+ typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type;
+ auto Args = std::make_shared<ArgsTuple>();
+
+ if (auto Err =
+ detail::HandlerTraits<typename Func::Type>::deserializeArgs(
+ Channel, *Args))
+ return Err;
+
+ // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
+ // for RPCArgs. Void cast RPCArgs to work around this for now.
+ // FIXME: Remove this workaround once we can assume a working GCC version.
+ (void)Args;
+
+ // End receieve message, unlocking the channel for reading.
+ if (auto Err = Channel.endReceiveMessage())
+ return Err;
+
+ using HTraits = detail::HandlerTraits<HandlerT>;
+ using FuncReturn = typename Func::ReturnType;
+ auto Responder =
+ [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error {
+ return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
+ std::move(RetVal));
+ };
+
+ return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
};
}
@@ -1068,66 +1419,31 @@ public:
MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
: BaseClass(C, LazyAutoNegotiation) {}
- /// The LaunchPolicy type allows a launch policy to be specified when adding
- /// a function handler. See addHandler.
- using LaunchPolicy = typename BaseClass::LaunchPolicy;
-
/// Add a handler for the given RPC function.
/// This installs the given handler functor for the given RPC Function, and
/// makes the RPC function available for negotiation/calling from the remote.
- ///
- /// The optional LaunchPolicy argument can be used to control how the handler
- /// is run when called:
- ///
- /// * If no LaunchPolicy is given, the handler code will be run on the RPC
- /// handler thread that is reading from the channel. This handler cannot
- /// make blocking RPC calls (since it would be blocking the thread used to
- /// get the result), but can make non-blocking calls.
- ///
- /// * If a LaunchPolicy is given, the user's handler will be wrapped in a
- /// call to serialize and send the result, and the resulting functor (with
- /// type 'Error()' will be passed to the LaunchPolicy. The user can then
- /// choose to add the wrapped handler to a work queue, spawn a new thread,
- /// or anything else.
template <typename Func, typename HandlerT>
- void addHandler(HandlerT Handler, LaunchPolicy Launch = LaunchPolicy()) {
- return this->template addHandlerImpl<Func>(std::move(Handler),
- std::move(Launch));
+ void addHandler(HandlerT Handler) {
+ return this->template addHandlerImpl<Func>(std::move(Handler));
}
/// Add a class-method as a handler.
template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
- void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...),
- LaunchPolicy Launch = LaunchPolicy()) {
+ void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
addHandler<Func>(
- detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method),
- Launch);
+ detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
}
- /// Negotiate a function id for Func with the other end of the channel.
- template <typename Func> Error negotiateFunction(bool Retry = false) {
- using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate;
-
- // Check if we already have a function id...
- auto I = this->RemoteFunctionIds.find(Func::getPrototype());
- if (I != this->RemoteFunctionIds.end()) {
- // If it's valid there's nothing left to do.
- if (I->second != this->getInvalidFunctionId())
- return Error::success();
- // If it's invalid and we can't re-attempt negotiation, throw an error.
- if (!Retry)
- return orcError(OrcErrorCode::UnknownRPCFunction);
- }
+ template <typename Func, typename HandlerT>
+ void addAsyncHandler(HandlerT Handler) {
+ return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
+ }
- // We don't have a function id for Func yet, call the remote to try to
- // negotiate one.
- if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) {
- this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
- if (*RemoteIdOrErr == this->getInvalidFunctionId())
- return orcError(OrcErrorCode::UnknownRPCFunction);
- return Error::success();
- } else
- return RemoteIdOrErr.takeError();
+ /// Add a class-method as a handler.
+ template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
+ void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
+ addAsyncHandler<Func>(
+ detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
}
/// Return type for non-blocking call primitives.
@@ -1159,7 +1475,6 @@ public:
return Error::success();
},
Args...)) {
- this->abandonPendingResponses();
RTraits::consumeAbandoned(FutureResult.get());
return std::move(Err);
}
@@ -1191,15 +1506,9 @@ public:
typename AltRetT = typename Func::ReturnType>
typename detail::ResultTraits<AltRetT>::ErrorReturnType
callB(const ArgTs &... Args) {
- if (auto FutureResOrErr = callNB<Func>(Args...)) {
- if (auto Err = this->C.send()) {
- this->abandonPendingResponses();
- detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
- std::move(FutureResOrErr->get()));
- return std::move(Err);
- }
+ if (auto FutureResOrErr = callNB<Func>(Args...))
return FutureResOrErr->get();
- } else
+ else
return FutureResOrErr.takeError();
}
@@ -1224,16 +1533,13 @@ private:
SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
ChannelT, FunctionIdT, SequenceNumberT>;
- using LaunchPolicy = typename BaseClass::LaunchPolicy;
-
public:
SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
: BaseClass(C, LazyAutoNegotiation) {}
template <typename Func, typename HandlerT>
void addHandler(HandlerT Handler) {
- return this->template addHandlerImpl<Func>(std::move(Handler),
- LaunchPolicy());
+ return this->template addHandlerImpl<Func>(std::move(Handler));
}
template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
@@ -1242,30 +1548,16 @@ public:
detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
}
- /// Negotiate a function id for Func with the other end of the channel.
- template <typename Func> Error negotiateFunction(bool Retry = false) {
- using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate;
-
- // Check if we already have a function id...
- auto I = this->RemoteFunctionIds.find(Func::getPrototype());
- if (I != this->RemoteFunctionIds.end()) {
- // If it's valid there's nothing left to do.
- if (I->second != this->getInvalidFunctionId())
- return Error::success();
- // If it's invalid and we can't re-attempt negotiation, throw an error.
- if (!Retry)
- return orcError(OrcErrorCode::UnknownRPCFunction);
- }
+ template <typename Func, typename HandlerT>
+ void addAsyncHandler(HandlerT Handler) {
+ return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
+ }
- // We don't have a function id for Func yet, call the remote to try to
- // negotiate one.
- if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) {
- this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
- if (*RemoteIdOrErr == this->getInvalidFunctionId())
- return orcError(OrcErrorCode::UnknownRPCFunction);
- return Error::success();
- } else
- return RemoteIdOrErr.takeError();
+ /// Add a class-method as a handler.
+ template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
+ void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
+ addAsyncHandler<Func>(
+ detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
}
template <typename Func, typename... ArgTs,
@@ -1287,7 +1579,6 @@ public:
return Error::success();
},
Args...)) {
- this->abandonPendingResponses();
detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
std::move(Result));
return std::move(Err);
@@ -1295,7 +1586,6 @@ public:
while (!ReceivedResponse) {
if (auto Err = this->handleOne()) {
- this->abandonPendingResponses();
detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
std::move(Result));
return std::move(Err);
@@ -1306,24 +1596,40 @@ public:
}
};
+/// Asynchronous dispatch for a function on an RPC endpoint.
+template <typename RPCClass, typename Func>
+class RPCAsyncDispatch {
+public:
+ RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
+
+ template <typename HandlerT, typename... ArgTs>
+ Error operator()(HandlerT Handler, const ArgTs &... Args) const {
+ return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
+ }
+
+private:
+ RPCClass &Endpoint;
+};
+
+/// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
+template <typename Func, typename RPCEndpointT>
+RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
+ return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
+}
+
/// \brief Allows a set of asynchrounous calls to be dispatched, and then
/// waited on as a group.
-template <typename RPCClass> class ParallelCallGroup {
+class ParallelCallGroup {
public:
- /// \brief Construct a parallel call group for the given RPC.
- ParallelCallGroup(RPCClass &RPC) : RPC(RPC), NumOutstandingCalls(0) {}
-
+ ParallelCallGroup() = default;
ParallelCallGroup(const ParallelCallGroup &) = delete;
ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
/// \brief Make as asynchronous call.
- ///
- /// Does not issue a send call to the RPC's channel. The channel may use this
- /// to batch up subsequent calls. A send will automatically be sent when wait
- /// is called.
- template <typename Func, typename HandlerT, typename... ArgTs>
- Error appendCall(HandlerT Handler, const ArgTs &... Args) {
+ template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
+ Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
+ const ArgTs &... Args) {
// Increment the count of outstanding calls. This has to happen before
// we invoke the call, as the handler may (depending on scheduling)
// be run immediately on another thread, and we don't want the decrement
@@ -1346,38 +1652,21 @@ public:
return Err;
};
- return RPC.template appendCallAsync<Func>(std::move(WrappedHandler),
- Args...);
- }
-
- /// \brief Make an asynchronous call.
- ///
- /// The same as appendCall, but also calls send on the channel immediately.
- /// Prefer appendCall if you are about to issue a "wait" call shortly, as
- /// this may allow the channel to better batch the calls.
- template <typename Func, typename HandlerT, typename... ArgTs>
- Error call(HandlerT Handler, const ArgTs &... Args) {
- if (auto Err = appendCall(std::move(Handler), Args...))
- return Err;
- return RPC.sendAppendedCalls();
+ return AsyncDispatch(std::move(WrappedHandler), Args...);
}
/// \brief Blocks until all calls have been completed and their return value
/// handlers run.
- Error wait() {
- if (auto Err = RPC.sendAppendedCalls())
- return Err;
+ void wait() {
std::unique_lock<std::mutex> Lock(M);
while (NumOutstandingCalls > 0)
CV.wait(Lock);
- return Error::success();
}
private:
- RPCClass &RPC;
std::mutex M;
std::condition_variable CV;
- uint32_t NumOutstandingCalls;
+ uint32_t NumOutstandingCalls = 0;
};
/// @brief Convenience class for grouping RPC Functions into APIs that can be