diff options
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc/RPCUtils.h')
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/RPCUtils.h | 614 |
1 files changed, 521 insertions, 93 deletions
diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 0bd5cbc0cdde0..966a49684348e 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -14,78 +14,256 @@ #ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H +#include <map> +#include <vector> + +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ExecutionEngine/Orc/OrcError.h" + +#ifdef _MSC_VER +// concrt.h depends on eh.h for __uncaught_exception declaration +// even if we disable exceptions. +#include <eh.h> + +// Disable warnings from ppltasks.h transitively included by <future>. +#pragma warning(push) +#pragma warning(disable : 4530) +#pragma warning(disable : 4062) +#endif + +#include <future> + +#ifdef _MSC_VER +#pragma warning(pop) +#endif namespace llvm { namespace orc { namespace remote { +/// Describes reserved RPC Function Ids. +/// +/// The default implementation will serve for integer and enum function id +/// types. If you want to use a custom type as your FunctionId you can +/// specialize this class and provide unique values for InvalidId, +/// ResponseId and FirstValidId. + +template <typename T> class RPCFunctionIdTraits { +public: + static const T InvalidId = static_cast<T>(0); + static const T ResponseId = static_cast<T>(1); + static const T FirstValidId = static_cast<T>(2); +}; + // Base class containing utilities that require partial specialization. // These cannot be included in RPC, as template class members cannot be // partially specialized. class RPCBase { protected: - template <typename ProcedureIdT, ProcedureIdT ProcId, typename... Ts> - class ProcedureHelper { + // RPC Function description type. + // + // This class provides the information and operations needed to support the + // RPC primitive operations (call, expect, etc) for a given function. It + // is specialized for void and non-void functions to deal with the differences + // betwen the two. Both specializations have the same interface: + // + // Id - The function's unique identifier. + // OptionalReturn - The return type for asyncronous calls. + // ErrorReturn - The return type for synchronous calls. + // optionalToErrorReturn - Conversion from a valid OptionalReturn to an + // ErrorReturn. + // readResult - Deserialize a result from a channel. + // abandon - Abandon a promised (asynchronous) result. + // respond - Retun a result on the channel. + template <typename FunctionIdT, FunctionIdT FuncId, typename FnT> + class FunctionHelper {}; + + // RPC Function description specialization for non-void functions. + template <typename FunctionIdT, FunctionIdT FuncId, typename RetT, + typename... ArgTs> + class FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> { public: - static const ProcedureIdT Id = ProcId; + static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId && + FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId, + "Cannot define custom function with InvalidId or ResponseId. " + "Please use RPCFunctionTraits<FunctionIdT>::FirstValidId."); + + static const FunctionIdT Id = FuncId; + + typedef Optional<RetT> OptionalReturn; + + typedef Expected<RetT> ErrorReturn; + + static ErrorReturn optionalToErrorReturn(OptionalReturn &&V) { + assert(V && "Return value not available"); + return std::move(*V); + } + + template <typename ChannelT> + static Error readResult(ChannelT &C, std::promise<OptionalReturn> &P) { + RetT Val; + auto Err = deserialize(C, Val); + auto Err2 = endReceiveMessage(C); + Err = joinErrors(std::move(Err), std::move(Err2)); + + if (Err) { + P.set_value(OptionalReturn()); + return Err; + } + P.set_value(std::move(Val)); + return Error::success(); + } + + static void abandon(std::promise<OptionalReturn> &P) { + P.set_value(OptionalReturn()); + } + + template <typename ChannelT, typename SequenceNumberT> + static Error respond(ChannelT &C, SequenceNumberT SeqNo, + ErrorReturn &Result) { + FunctionIdT ResponseId = RPCFunctionIdTraits<FunctionIdT>::ResponseId; + + // If the handler returned an error then bail out with that. + if (!Result) + return Result.takeError(); + + // Otherwise open a new message on the channel and send the result. + if (auto Err = startSendMessage(C)) + return Err; + if (auto Err = serializeSeq(C, ResponseId, SeqNo, *Result)) + return Err; + return endSendMessage(C); + } }; - template <typename ChannelT, typename Proc> class CallHelper; + // RPC Function description specialization for void functions. + template <typename FunctionIdT, FunctionIdT FuncId, typename... ArgTs> + class FunctionHelper<FunctionIdT, FuncId, void(ArgTs...)> { + public: + static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId && + FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId, + "Cannot define custom function with InvalidId or ResponseId. " + "Please use RPCFunctionTraits<FunctionIdT>::FirstValidId."); - template <typename ChannelT, typename ProcedureIdT, ProcedureIdT ProcId, - typename... ArgTs> - class CallHelper<ChannelT, ProcedureHelper<ProcedureIdT, ProcId, ArgTs...>> { + static const FunctionIdT Id = FuncId; + + typedef bool OptionalReturn; + typedef Error ErrorReturn; + + static ErrorReturn optionalToErrorReturn(OptionalReturn &&V) { + assert(V && "Return value not available"); + return Error::success(); + } + + template <typename ChannelT> + static Error readResult(ChannelT &C, std::promise<OptionalReturn> &P) { + // Void functions don't have anything to deserialize, so we're good. + P.set_value(true); + return endReceiveMessage(C); + } + + static void abandon(std::promise<OptionalReturn> &P) { P.set_value(false); } + + template <typename ChannelT, typename SequenceNumberT> + static Error respond(ChannelT &C, SequenceNumberT SeqNo, + ErrorReturn &Result) { + const FunctionIdT ResponseId = + RPCFunctionIdTraits<FunctionIdT>::ResponseId; + + // If the handler returned an error then bail out with that. + if (Result) + return std::move(Result); + + // Otherwise open a new message on the channel and send the result. + if (auto Err = startSendMessage(C)) + return Err; + if (auto Err = serializeSeq(C, ResponseId, SeqNo)) + return Err; + return endSendMessage(C); + } + }; + + // Helper for the call primitive. + template <typename ChannelT, typename SequenceNumberT, typename Func> + class CallHelper; + + template <typename ChannelT, typename SequenceNumberT, typename FunctionIdT, + FunctionIdT FuncId, typename RetT, typename... ArgTs> + class CallHelper<ChannelT, SequenceNumberT, + FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> { public: - static std::error_code call(ChannelT &C, const ArgTs &... Args) { - if (auto EC = serialize(C, ProcId)) - return EC; - // If you see a compile-error on this line you're probably calling a - // function with the wrong signature. - return serialize_seq(C, Args...); + static Error call(ChannelT &C, SequenceNumberT SeqNo, + const ArgTs &... Args) { + if (auto Err = startSendMessage(C)) + return Err; + if (auto Err = serializeSeq(C, FuncId, SeqNo, Args...)) + return Err; + return endSendMessage(C); } }; - template <typename ChannelT, typename Proc> class HandlerHelper; + // Helper for handle primitive. + template <typename ChannelT, typename SequenceNumberT, typename Func> + class HandlerHelper; - template <typename ChannelT, typename ProcedureIdT, ProcedureIdT ProcId, - typename... ArgTs> - class HandlerHelper<ChannelT, - ProcedureHelper<ProcedureIdT, ProcId, ArgTs...>> { + template <typename ChannelT, typename SequenceNumberT, typename FunctionIdT, + FunctionIdT FuncId, typename RetT, typename... ArgTs> + class HandlerHelper<ChannelT, SequenceNumberT, + FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> { public: template <typename HandlerT> - static std::error_code handle(ChannelT &C, HandlerT Handler) { + static Error handle(ChannelT &C, HandlerT Handler) { return readAndHandle(C, Handler, llvm::index_sequence_for<ArgTs...>()); } private: + typedef FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> Func; + template <typename HandlerT, size_t... Is> - static std::error_code readAndHandle(ChannelT &C, HandlerT Handler, - llvm::index_sequence<Is...> _) { + static Error readAndHandle(ChannelT &C, HandlerT Handler, + llvm::index_sequence<Is...> _) { std::tuple<ArgTs...> RPCArgs; - if (auto EC = deserialize_seq(C, std::get<Is>(RPCArgs)...)) - return EC; - return Handler(std::get<Is>(RPCArgs)...); + SequenceNumberT SeqNo; + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)RPCArgs; + if (auto Err = deserializeSeq(C, SeqNo, std::get<Is>(RPCArgs)...)) + return Err; + + // We've deserialized the arguments, so unlock the channel for reading + // before we call the handler. This allows recursive RPC calls. + if (auto Err = endReceiveMessage(C)) + return Err; + + // Run the handler and get the result. + auto Result = Handler(std::get<Is>(RPCArgs)...); + + // Return the result to the client. + return Func::template respond<ChannelT, SequenceNumberT>(C, SeqNo, + Result); } }; - template <typename ClassT, typename... ArgTs> class MemberFnWrapper { + // Helper for wrapping member functions up as functors. + template <typename ClassT, typename RetT, typename... ArgTs> + class MemberFnWrapper { public: - typedef std::error_code (ClassT::*MethodT)(ArgTs...); + typedef RetT (ClassT::*MethodT)(ArgTs...); MemberFnWrapper(ClassT &Instance, MethodT Method) : Instance(Instance), Method(Method) {} - std::error_code operator()(ArgTs &... Args) { - return (Instance.*Method)(Args...); - } + RetT operator()(ArgTs &... Args) { return (Instance.*Method)(Args...); } private: ClassT &Instance; MethodT Method; }; + // Helper that provides a Functor for deserializing arguments. template <typename... ArgTs> class ReadArgs { public: - std::error_code operator()() { return std::error_code(); } + Error operator()() { return Error::success(); } }; template <typename ArgT, typename... ArgTs> @@ -94,7 +272,7 @@ protected: ReadArgs(ArgT &Arg, ArgTs &... Args) : ReadArgs<ArgTs...>(Args...), Arg(Arg) {} - std::error_code operator()(ArgT &ArgVal, ArgTs &... ArgVals) { + Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) { this->Arg = std::move(ArgVal); return ReadArgs<ArgTs...>::operator()(ArgVals...); } @@ -106,7 +284,7 @@ protected: /// Contains primitive utilities for defining, calling and handling calls to /// remote procedures. ChannelT is a bidirectional stream conforming to the -/// RPCChannel interface (see RPCChannel.h), and ProcedureIdT is a procedure +/// RPCChannel interface (see RPCChannel.h), and FunctionIdT is a procedure /// identifier type that must be serializable on ChannelT. /// /// These utilities support the construction of very primitive RPC utilities. @@ -123,120 +301,223 @@ protected: /// /// Overview (see comments individual types/methods for details): /// -/// Procedure<Id, Args...> : +/// Function<Id, Args...> : /// /// associates a unique serializable id with an argument list. /// /// -/// call<Proc>(Channel, Args...) : +/// call<Func>(Channel, Args...) : /// -/// Calls the remote procedure 'Proc' by serializing Proc's id followed by its +/// Calls the remote procedure 'Func' by serializing Func's id followed by its /// arguments and sending the resulting bytes to 'Channel'. /// /// -/// handle<Proc>(Channel, <functor matching std::error_code(Args...)> : +/// handle<Func>(Channel, <functor matching Error(Args...)> : /// -/// Handles a call to 'Proc' by deserializing its arguments and calling the -/// given functor. This assumes that the id for 'Proc' has already been +/// Handles a call to 'Func' by deserializing its arguments and calling the +/// given functor. This assumes that the id for 'Func' has already been /// deserialized. /// -/// expect<Proc>(Channel, <functor matching std::error_code(Args...)> : +/// expect<Func>(Channel, <functor matching Error(Args...)> : /// /// The same as 'handle', except that the procedure id should not have been -/// read yet. Expect will deserialize the id and assert that it matches Proc's +/// read yet. Expect will deserialize the id and assert that it matches Func's /// id. If it does not, and unexpected RPC call error is returned. - -template <typename ChannelT, typename ProcedureIdT = uint32_t> +template <typename ChannelT, typename FunctionIdT = uint32_t, + typename SequenceNumberT = uint16_t> class RPC : public RPCBase { public: + /// RPC default constructor. + RPC() = default; + + /// RPC instances cannot be copied. + RPC(const RPC &) = delete; + + /// RPC instances cannot be copied. + RPC &operator=(const RPC &) = delete; + + /// RPC move constructor. + // FIXME: Remove once MSVC can synthesize move ops. + RPC(RPC &&Other) + : SequenceNumberMgr(std::move(Other.SequenceNumberMgr)), + OutstandingResults(std::move(Other.OutstandingResults)) {} + + /// RPC move assignment. + // FIXME: Remove once MSVC can synthesize move ops. + RPC &operator=(RPC &&Other) { + SequenceNumberMgr = std::move(Other.SequenceNumberMgr); + OutstandingResults = std::move(Other.OutstandingResults); + return *this; + } + /// Utility class for defining/referring to RPC procedures. /// /// Typedefs of this utility are used when calling/handling remote procedures. /// - /// ProcId should be a unique value of ProcedureIdT (i.e. not used with any - /// other Procedure typedef in the RPC API being defined. + /// FuncId should be a unique value of FunctionIdT (i.e. not used with any + /// other Function typedef in the RPC API being defined. /// /// the template argument Ts... gives the argument list for the remote /// procedure. /// /// E.g. /// - /// typedef Procedure<0, bool> Proc1; - /// typedef Procedure<1, std::string, std::vector<int>> Proc2; + /// typedef Function<0, bool> Func1; + /// typedef Function<1, std::string, std::vector<int>> Func2; /// - /// if (auto EC = call<Proc1>(Channel, true)) - /// /* handle EC */; + /// if (auto Err = call<Func1>(Channel, true)) + /// /* handle Err */; /// - /// if (auto EC = expect<Proc2>(Channel, + /// if (auto Err = expect<Func2>(Channel, /// [](std::string &S, std::vector<int> &V) { /// // Stuff. - /// return std::error_code(); + /// return Error::success(); /// }) - /// /* handle EC */; + /// /* handle Err */; /// - template <ProcedureIdT ProcId, typename... Ts> - using Procedure = ProcedureHelper<ProcedureIdT, ProcId, Ts...>; + template <FunctionIdT FuncId, typename FnT> + using Function = FunctionHelper<FunctionIdT, FuncId, FnT>; + + /// Return type for asynchronous call primitives. + template <typename Func> + using AsyncCallResult = std::future<typename Func::OptionalReturn>; + + /// Return type for asynchronous call-with-seq primitives. + template <typename Func> + using AsyncCallWithSeqResult = + std::pair<std::future<typename Func::OptionalReturn>, SequenceNumberT>; /// Serialize Args... to channel C, but do not call C.send(). /// - /// For buffered channels, this can be used to queue up several calls before - /// flushing the channel. - template <typename Proc, typename... ArgTs> - static std::error_code appendCall(ChannelT &C, const ArgTs &... Args) { - return CallHelper<ChannelT, Proc>::call(C, Args...); + /// Returns an error (on serialization failure) or a pair of: + /// (1) A future Optional<T> (or future<bool> for void functions), and + /// (2) A sequence number. + /// + /// This utility function is primarily used for single-threaded mode support, + /// where the sequence number can be used to wait for the corresponding + /// result. In multi-threaded mode the appendCallAsync method, which does not + /// return the sequence numeber, should be preferred. + template <typename Func, typename... ArgTs> + Expected<AsyncCallWithSeqResult<Func>> + appendCallAsyncWithSeq(ChannelT &C, const ArgTs &... Args) { + auto SeqNo = SequenceNumberMgr.getSequenceNumber(); + std::promise<typename Func::OptionalReturn> Promise; + auto Result = Promise.get_future(); + OutstandingResults[SeqNo] = + createOutstandingResult<Func>(std::move(Promise)); + + if (auto Err = CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo, + Args...)) { + abandonOutstandingResults(); + return std::move(Err); + } else + return AsyncCallWithSeqResult<Func>(std::move(Result), SeqNo); } - /// Serialize Args... to channel C and call C.send(). - template <typename Proc, typename... ArgTs> - static std::error_code call(ChannelT &C, const ArgTs &... Args) { - if (auto EC = appendCall<Proc>(C, Args...)) - return EC; - return C.send(); + /// The same as appendCallAsyncWithSeq, except that it calls C.send() to + /// flush the channel after serializing the call. + template <typename Func, typename... ArgTs> + Expected<AsyncCallWithSeqResult<Func>> + callAsyncWithSeq(ChannelT &C, const ArgTs &... Args) { + auto Result = appendCallAsyncWithSeq<Func>(C, Args...); + if (!Result) + return Result; + if (auto Err = C.send()) { + abandonOutstandingResults(); + return std::move(Err); + } + return Result; + } + + /// Serialize Args... to channel C, but do not call send. + /// Returns an error if serialization fails, otherwise returns a + /// std::future<Optional<T>> (or a future<bool> for void functions). + template <typename Func, typename... ArgTs> + Expected<AsyncCallResult<Func>> appendCallAsync(ChannelT &C, + const ArgTs &... Args) { + auto ResAndSeqOrErr = appendCallAsyncWithSeq<Func>(C, Args...); + if (ResAndSeqOrErr) + return std::move(ResAndSeqOrErr->first); + return ResAndSeqOrErr.getError(); + } + + /// The same as appendCallAsync, except that it calls C.send to flush the + /// channel after serializing the call. + template <typename Func, typename... ArgTs> + Expected<AsyncCallResult<Func>> callAsync(ChannelT &C, + const ArgTs &... Args) { + auto ResAndSeqOrErr = callAsyncWithSeq<Func>(C, Args...); + if (ResAndSeqOrErr) + return std::move(ResAndSeqOrErr->first); + return ResAndSeqOrErr.getError(); + } + + /// This can be used in single-threaded mode. + template <typename Func, typename HandleFtor, typename... ArgTs> + typename Func::ErrorReturn + callSTHandling(ChannelT &C, HandleFtor &HandleOther, const ArgTs &... Args) { + if (auto ResultAndSeqNoOrErr = callAsyncWithSeq<Func>(C, Args...)) { + auto &ResultAndSeqNo = *ResultAndSeqNoOrErr; + if (auto Err = waitForResult(C, ResultAndSeqNo.second, HandleOther)) + return std::move(Err); + return Func::optionalToErrorReturn(ResultAndSeqNo.first.get()); + } else + return ResultAndSeqNoOrErr.takeError(); } - /// Deserialize and return an enum whose underlying type is ProcedureIdT. - static std::error_code getNextProcId(ChannelT &C, ProcedureIdT &Id) { + // This can be used in single-threaded mode. + template <typename Func, typename... ArgTs> + typename Func::ErrorReturn callST(ChannelT &C, const ArgTs &... Args) { + return callSTHandling<Func>(C, handleNone, Args...); + } + + /// Start receiving a new function call. + /// + /// Calls startReceiveMessage on the channel, then deserializes a FunctionId + /// into Id. + Error startReceivingFunction(ChannelT &C, FunctionIdT &Id) { + if (auto Err = startReceiveMessage(C)) + return Err; + return deserialize(C, Id); } - /// Deserialize args for Proc from C and call Handler. The signature of - /// handler must conform to 'std::error_code(Args...)' where Args... matches - /// the arguments used in the Proc typedef. - template <typename Proc, typename HandlerT> - static std::error_code handle(ChannelT &C, HandlerT Handler) { - return HandlerHelper<ChannelT, Proc>::handle(C, Handler); + /// Deserialize args for Func from C and call Handler. The signature of + /// handler must conform to 'Error(Args...)' where Args... matches + /// the arguments used in the Func typedef. + template <typename Func, typename HandlerT> + static Error handle(ChannelT &C, HandlerT Handler) { + return HandlerHelper<ChannelT, SequenceNumberT, Func>::handle(C, Handler); } /// Helper version of 'handle' for calling member functions. - template <typename Proc, typename ClassT, typename... ArgTs> - static std::error_code - handle(ChannelT &C, ClassT &Instance, - std::error_code (ClassT::*HandlerMethod)(ArgTs...)) { - return handle<Proc>( - C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod)); + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + static Error handle(ChannelT &C, ClassT &Instance, + RetT (ClassT::*HandlerMethod)(ArgTs...)) { + return handle<Func>( + C, MemberFnWrapper<ClassT, RetT, ArgTs...>(Instance, HandlerMethod)); } - /// Deserialize a ProcedureIdT from C and verify it matches the id for Proc. + /// Deserialize a FunctionIdT from C and verify it matches the id for Func. /// If the id does match, deserialize the arguments and call the handler /// (similarly to handle). /// If the id does not match, return an unexpect RPC call error and do not /// deserialize any further bytes. - template <typename Proc, typename HandlerT> - static std::error_code expect(ChannelT &C, HandlerT Handler) { - ProcedureIdT ProcId; - if (auto EC = getNextProcId(C, ProcId)) - return EC; - if (ProcId != Proc::Id) + template <typename Func, typename HandlerT> + Error expect(ChannelT &C, HandlerT Handler) { + FunctionIdT FuncId; + if (auto Err = startReceivingFunction(C, FuncId)) + return std::move(Err); + if (FuncId != Func::Id) return orcError(OrcErrorCode::UnexpectedRPCCall); - return handle<Proc>(C, Handler); + return handle<Func>(C, Handler); } /// Helper version of expect for calling member functions. - template <typename Proc, typename ClassT, typename... ArgTs> - static std::error_code - expect(ChannelT &C, ClassT &Instance, - std::error_code (ClassT::*HandlerMethod)(ArgTs...)) { - return expect<Proc>( + template <typename Func, typename ClassT, typename... ArgTs> + static Error expect(ChannelT &C, ClassT &Instance, + Error (ClassT::*HandlerMethod)(ArgTs...)) { + return expect<Func>( C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod)); } @@ -245,18 +526,165 @@ public: /// channel. /// E.g. /// - /// typedef Procedure<0, bool, int> Proc1; + /// typedef Function<0, bool, int> Func1; /// /// ... /// bool B; /// int I; - /// if (auto EC = expect<Proc1>(Channel, readArgs(B, I))) + /// if (auto Err = expect<Func1>(Channel, readArgs(B, I))) /// /* Handle Args */ ; /// template <typename... ArgTs> static ReadArgs<ArgTs...> readArgs(ArgTs &... Args) { return ReadArgs<ArgTs...>(Args...); } + + /// Read a response from Channel. + /// This should be called from the receive loop to retrieve results. + Error handleResponse(ChannelT &C, SequenceNumberT *SeqNoRet = nullptr) { + SequenceNumberT SeqNo; + if (auto Err = deserialize(C, SeqNo)) { + abandonOutstandingResults(); + return Err; + } + + if (SeqNoRet) + *SeqNoRet = SeqNo; + + auto I = OutstandingResults.find(SeqNo); + if (I == OutstandingResults.end()) { + abandonOutstandingResults(); + return orcError(OrcErrorCode::UnexpectedRPCResponse); + } + + if (auto Err = I->second->readResult(C)) { + abandonOutstandingResults(); + // FIXME: Release sequence numbers? + return Err; + } + + OutstandingResults.erase(I); + SequenceNumberMgr.releaseSequenceNumber(SeqNo); + + return Error::success(); + } + + // Loop waiting for a result with the given sequence number. + // This can be used as a receive loop if the user doesn't have a default. + template <typename HandleOtherFtor> + Error waitForResult(ChannelT &C, SequenceNumberT TgtSeqNo, + HandleOtherFtor &HandleOther = handleNone) { + bool GotTgtResult = false; + + while (!GotTgtResult) { + FunctionIdT Id = RPCFunctionIdTraits<FunctionIdT>::InvalidId; + if (auto Err = startReceivingFunction(C, Id)) + return Err; + if (Id == RPCFunctionIdTraits<FunctionIdT>::ResponseId) { + SequenceNumberT SeqNo; + if (auto Err = handleResponse(C, &SeqNo)) + return Err; + GotTgtResult = (SeqNo == TgtSeqNo); + } else if (auto Err = HandleOther(C, Id)) + return Err; + } + + return Error::success(); + } + + // Default handler for 'other' (non-response) functions when waiting for a + // result from the channel. + static Error handleNone(ChannelT &, FunctionIdT) { + return orcError(OrcErrorCode::UnexpectedRPCCall); + }; + +private: + // Manage sequence numbers. + class SequenceNumberManager { + public: + SequenceNumberManager() = default; + + SequenceNumberManager(const SequenceNumberManager &) = delete; + SequenceNumberManager &operator=(const SequenceNumberManager &) = delete; + + SequenceNumberManager(SequenceNumberManager &&Other) + : NextSequenceNumber(std::move(Other.NextSequenceNumber)), + FreeSequenceNumbers(std::move(Other.FreeSequenceNumbers)) {} + + SequenceNumberManager &operator=(SequenceNumberManager &&Other) { + NextSequenceNumber = std::move(Other.NextSequenceNumber); + FreeSequenceNumbers = std::move(Other.FreeSequenceNumbers); + } + + void reset() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + NextSequenceNumber = 0; + FreeSequenceNumbers.clear(); + } + + SequenceNumberT getSequenceNumber() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + if (FreeSequenceNumbers.empty()) + return NextSequenceNumber++; + auto SequenceNumber = FreeSequenceNumbers.back(); + FreeSequenceNumbers.pop_back(); + return SequenceNumber; + } + + void releaseSequenceNumber(SequenceNumberT SequenceNumber) { + std::lock_guard<std::mutex> Lock(SeqNoLock); + FreeSequenceNumbers.push_back(SequenceNumber); + } + + private: + std::mutex SeqNoLock; + SequenceNumberT NextSequenceNumber = 0; + std::vector<SequenceNumberT> FreeSequenceNumbers; + }; + + // Base class for results that haven't been returned from the other end of the + // RPC connection yet. + class OutstandingResult { + public: + virtual ~OutstandingResult() {} + virtual Error readResult(ChannelT &C) = 0; + virtual void abandon() = 0; + }; + + // Outstanding results for a specific function. + template <typename Func> + class OutstandingResultImpl : public OutstandingResult { + private: + public: + OutstandingResultImpl(std::promise<typename Func::OptionalReturn> &&P) + : P(std::move(P)) {} + + Error readResult(ChannelT &C) override { return Func::readResult(C, P); } + + void abandon() override { Func::abandon(P); } + + private: + std::promise<typename Func::OptionalReturn> P; + }; + + // Create an outstanding result for the given function. + template <typename Func> + std::unique_ptr<OutstandingResult> + createOutstandingResult(std::promise<typename Func::OptionalReturn> &&P) { + return llvm::make_unique<OutstandingResultImpl<Func>>(std::move(P)); + } + + // Abandon all outstanding results. + void abandonOutstandingResults() { + for (auto &KV : OutstandingResults) + KV.second->abandon(); + OutstandingResults.clear(); + SequenceNumberMgr.reset(); + } + + SequenceNumberManager SequenceNumberMgr; + std::map<SequenceNumberT, std::unique_ptr<OutstandingResult>> + OutstandingResults; }; } // end namespace remote |