diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2017-01-09 21:23:09 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2017-01-09 21:23:09 +0000 |
commit | 909545a822eef491158f831688066f0ec2866938 (patch) | |
tree | 5b0bf0e81294007a9b462b21031b3df272c655c3 /include/llvm/ExecutionEngine/Orc | |
parent | 7e7b6700743285c0af506ac6299ddf82ebd434b9 (diff) |
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc')
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h | 4 | ||||
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/RPCUtils.h | 246 | ||||
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/RawByteChannel.h | 4 |
3 files changed, 181 insertions, 73 deletions
diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h index ab2b0fad89fd6..3086ef0cdf803 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h @@ -83,7 +83,7 @@ public: namespace remote { class OrcRemoteTargetRPCAPI - : public rpc::SingleThreadedRPC<rpc::RawByteChannel> { + : public rpc::SingleThreadedRPCEndpoint<rpc::RawByteChannel> { protected: class ResourceIdMgr { public: @@ -108,7 +108,7 @@ protected: public: // FIXME: Remove constructors once MSVC supports synthesizing move-ops. OrcRemoteTargetRPCAPI(rpc::RawByteChannel &C) - : rpc::SingleThreadedRPC<rpc::RawByteChannel>(C, true) {} + : rpc::SingleThreadedRPCEndpoint<rpc::RawByteChannel>(C, true) {} class CallIntVoid : public rpc::Function<CallIntVoid, int32_t(JITTargetAddress Addr)> { diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h index f51fbe153a41c..37e2e66e5af48 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -702,7 +702,7 @@ public: /// sync. template <typename ImplT, typename ChannelT, typename FunctionIdT, typename SequenceNumberT> -class RPCBase { +class RPCEndpointBase { protected: class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> { public: @@ -747,7 +747,7 @@ protected: public: /// Construct an RPC instance on a channel. - RPCBase(ChannelT &C, bool LazyAutoNegotiation) + RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation) : C(C), LazyAutoNegotiation(LazyAutoNegotiation) { // Hold ResponseId in a special variable, since we expect Response to be // called relatively frequently, and want to avoid the map lookup. @@ -788,15 +788,21 @@ public: return FnIdOrErr.takeError(); } - // Allocate a sequence number. - auto SeqNo = SequenceNumberMgr.getSequenceNumber(); - assert(!PendingResponses.count(SeqNo) && - "Sequence number already allocated"); + SequenceNumberT SeqNo; // initialized in locked scope below. + { + // Lock the pending responses map and sequence number manager. + std::lock_guard<std::mutex> Lock(ResponsesMutex); + + // Allocate a sequence number. + SeqNo = SequenceNumberMgr.getSequenceNumber(); + assert(!PendingResponses.count(SeqNo) && + "Sequence number already allocated"); - // Install the user handler. - PendingResponses[SeqNo] = + // Install the user handler. + PendingResponses[SeqNo] = detail::createResponseHandler<ChannelT, typename Func::ReturnType>( std::move(Handler)); + } // Open the function call message. if (auto Err = C.startSendMessage(FnId, SeqNo)) { @@ -863,11 +869,33 @@ public: return detail::ReadArgs<ArgTs...>(Args...); } + /// Abandon all outstanding result handlers. + /// + /// This will call all currently registered result handlers to receive an + /// "abandoned" error as their argument. This is used internally by the RPC + /// in error situations, but can also be called directly by clients who are + /// disconnecting from the remote and don't or can't expect responses to their + /// outstanding calls. (Especially for outstanding blocking calls, calling + /// this function may be necessary to avoid dead threads). + void abandonPendingResponses() { + // Lock the pending responses map and sequence number manager. + std::lock_guard<std::mutex> Lock(ResponsesMutex); + + for (auto &KV : PendingResponses) + KV.second->abandon(); + PendingResponses.clear(); + SequenceNumberMgr.reset(); + } + protected: // The LaunchPolicy type allows a launch policy to be specified when adding // a function handler. See addHandlerImpl. using LaunchPolicy = std::function<Error(std::function<Error()>)>; + FunctionIdT getInvalidFunctionId() const { + return FnIdAllocator.getInvalidId(); + } + /// Add the given handler to the handler map and make it available for /// autonegotiation and execution. template <typename Func, typename HandlerT> @@ -884,28 +912,32 @@ protected: wrapHandler<Func>(std::move(Handler), std::move(Launch)); } - // Abandon all outstanding results. - void abandonPendingResponses() { - for (auto &KV : PendingResponses) - KV.second->abandon(); - PendingResponses.clear(); - SequenceNumberMgr.reset(); - } - Error handleResponse(SequenceNumberT SeqNo) { - auto I = PendingResponses.find(SeqNo); - if (I == PendingResponses.end()) { - abandonPendingResponses(); - return orcError(OrcErrorCode::UnexpectedRPCResponse); + using Handler = typename decltype(PendingResponses)::mapped_type; + Handler PRHandler; + + { + // Lock the pending responses map and sequence number manager. + std::unique_lock<std::mutex> Lock(ResponsesMutex); + auto I = PendingResponses.find(SeqNo); + + if (I != PendingResponses.end()) { + PRHandler = std::move(I->second); + PendingResponses.erase(I); + SequenceNumberMgr.releaseSequenceNumber(SeqNo); + } else { + // Unlock the pending results map to prevent recursive lock. + Lock.unlock(); + abandonPendingResponses(); + return orcError(OrcErrorCode::UnexpectedRPCResponse); + } } - auto PRHandler = std::move(I->second); - PendingResponses.erase(I); - SequenceNumberMgr.releaseSequenceNumber(SeqNo); + assert(PRHandler && + "If we didn't find a response handler we should have bailed out"); if (auto Err = PRHandler->handleResponse(C)) { abandonPendingResponses(); - SequenceNumberMgr.reset(); return Err; } @@ -915,7 +947,7 @@ protected: FunctionIdT handleNegotiate(const std::string &Name) { auto I = LocalFunctionIds.find(Name); if (I == LocalFunctionIds.end()) - return FnIdAllocator.getInvalidId(); + return getInvalidFunctionId(); return I->second; } @@ -938,7 +970,7 @@ protected: // If autonegotiation indicates that the remote end doesn't support this // function, return an unknown function error. - if (RemoteId == FnIdAllocator.getInvalidId()) + if (RemoteId == getInvalidFunctionId()) return orcError(OrcErrorCode::UnknownRPCFunction); // Autonegotiation succeeded and returned a valid id. Update the map and @@ -1012,6 +1044,7 @@ protected: std::map<FunctionIdT, WrappedHandlerFn> Handlers; + std::mutex ResponsesMutex; detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr; std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>> PendingResponses; @@ -1021,17 +1054,18 @@ protected: template <typename ChannelT, typename FunctionIdT = uint32_t, typename SequenceNumberT = uint32_t> -class MultiThreadedRPC - : public detail::RPCBase< - MultiThreadedRPC<ChannelT, FunctionIdT, SequenceNumberT>, ChannelT, - FunctionIdT, SequenceNumberT> { +class MultiThreadedRPCEndpoint + : public detail::RPCEndpointBase< + MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT> { private: using BaseClass = - detail::RPCBase<MultiThreadedRPC<ChannelT, FunctionIdT, SequenceNumberT>, - ChannelT, FunctionIdT, SequenceNumberT>; + detail::RPCEndpointBase< + MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT>; public: - MultiThreadedRPC(ChannelT &C, bool LazyAutoNegotiation) + MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) : BaseClass(C, LazyAutoNegotiation) {} /// The LaunchPolicy type allows a launch policy to be specified when adding @@ -1061,30 +1095,41 @@ public: std::move(Launch)); } + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...), + LaunchPolicy Launch = LaunchPolicy()) { + addHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method), + Launch); + } + /// Negotiate a function id for Func with the other end of the channel. - template <typename Func> Error negotiateFunction() { + template <typename Func> Error negotiateFunction(bool Retry = false) { using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; + // Check if we already have a function id... + auto I = this->RemoteFunctionIds.find(Func::getPrototype()); + if (I != this->RemoteFunctionIds.end()) { + // If it's valid there's nothing left to do. + if (I->second != this->getInvalidFunctionId()) + return Error::success(); + // If it's invalid and we can't re-attempt negotiation, throw an error. + if (!Retry) + return orcError(OrcErrorCode::UnknownRPCFunction); + } + + // We don't have a function id for Func yet, call the remote to try to + // negotiate one. if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) { this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + if (*RemoteIdOrErr == this->getInvalidFunctionId()) + return orcError(OrcErrorCode::UnknownRPCFunction); return Error::success(); } else return RemoteIdOrErr.takeError(); } - /// Convenience method for negotiating multiple functions at once. - template <typename Func> Error negotiateFunctions() { - return negotiateFunction<Func>(); - } - - /// Convenience method for negotiating multiple functions at once. - template <typename Func1, typename Func2, typename... Funcs> - Error negotiateFunctions() { - if (auto Err = negotiateFunction<Func1>()) - return Err; - return negotiateFunctions<Func2, Funcs...>(); - } - /// Return type for non-blocking call primitives. template <typename Func> using NonBlockingCallResult = typename detail::ResultTraits< @@ -1169,19 +1214,20 @@ public: template <typename ChannelT, typename FunctionIdT = uint32_t, typename SequenceNumberT = uint32_t> -class SingleThreadedRPC - : public detail::RPCBase< - SingleThreadedRPC<ChannelT, FunctionIdT, SequenceNumberT>, ChannelT, - FunctionIdT, SequenceNumberT> { +class SingleThreadedRPCEndpoint + : public detail::RPCEndpointBase< + SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT> { private: using BaseClass = - detail::RPCBase<SingleThreadedRPC<ChannelT, FunctionIdT, SequenceNumberT>, - ChannelT, FunctionIdT, SequenceNumberT>; + detail::RPCEndpointBase< + SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, + ChannelT, FunctionIdT, SequenceNumberT>; using LaunchPolicy = typename BaseClass::LaunchPolicy; public: - SingleThreadedRPC(ChannelT &C, bool LazyAutoNegotiation) + SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) : BaseClass(C, LazyAutoNegotiation) {} template <typename Func, typename HandlerT> @@ -1197,29 +1243,31 @@ public: } /// Negotiate a function id for Func with the other end of the channel. - template <typename Func> Error negotiateFunction() { + template <typename Func> Error negotiateFunction(bool Retry = false) { using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; + // Check if we already have a function id... + auto I = this->RemoteFunctionIds.find(Func::getPrototype()); + if (I != this->RemoteFunctionIds.end()) { + // If it's valid there's nothing left to do. + if (I->second != this->getInvalidFunctionId()) + return Error::success(); + // If it's invalid and we can't re-attempt negotiation, throw an error. + if (!Retry) + return orcError(OrcErrorCode::UnknownRPCFunction); + } + + // We don't have a function id for Func yet, call the remote to try to + // negotiate one. if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) { this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + if (*RemoteIdOrErr == this->getInvalidFunctionId()) + return orcError(OrcErrorCode::UnknownRPCFunction); return Error::success(); } else return RemoteIdOrErr.takeError(); } - /// Convenience method for negotiating multiple functions at once. - template <typename Func> Error negotiateFunctions() { - return negotiateFunction<Func>(); - } - - /// Convenience method for negotiating multiple functions at once. - template <typename Func1, typename Func2, typename... Funcs> - Error negotiateFunctions() { - if (auto Err = negotiateFunction<Func1>()) - return Err; - return negotiateFunctions<Func2, Funcs...>(); - } - template <typename Func, typename... ArgTs, typename AltRetT = typename Func::ReturnType> typename detail::ResultTraits<AltRetT>::ErrorReturnType @@ -1332,6 +1380,68 @@ private: uint32_t NumOutstandingCalls; }; +/// @brief Convenience class for grouping RPC Functions into APIs that can be +/// negotiated as a block. +/// +template <typename... Funcs> +class APICalls { +public: + + /// @brief Test whether this API contains Function F. + template <typename F> + class Contains { + public: + static const bool value = false; + }; + + /// @brief Negotiate all functions in this API. + template <typename RPCEndpoint> + static Error negotiate(RPCEndpoint &R) { + return Error::success(); + } +}; + +template <typename Func, typename... Funcs> +class APICalls<Func, Funcs...> { +public: + + template <typename F> + class Contains { + public: + static const bool value = std::is_same<F, Func>::value | + APICalls<Funcs...>::template Contains<F>::value; + }; + + template <typename RPCEndpoint> + static Error negotiate(RPCEndpoint &R) { + if (auto Err = R.template negotiateFunction<Func>()) + return Err; + return APICalls<Funcs...>::negotiate(R); + } + +}; + +template <typename... InnerFuncs, typename... Funcs> +class APICalls<APICalls<InnerFuncs...>, Funcs...> { +public: + + template <typename F> + class Contains { + public: + static const bool value = + APICalls<InnerFuncs...>::template Contains<F>::value | + APICalls<Funcs...>::template Contains<F>::value; + }; + + template <typename RPCEndpoint> + static Error negotiate(RPCEndpoint &R) { + if (auto Err = APICalls<InnerFuncs...>::negotiate(R)) + return Err; + return APICalls<Funcs...>::negotiate(R); + } + +}; + } // end namespace rpc } // end namespace orc } // end namespace llvm diff --git a/include/llvm/ExecutionEngine/Orc/RawByteChannel.h b/include/llvm/ExecutionEngine/Orc/RawByteChannel.h index 83a7b9a844f2a..3b6c84eb19650 100644 --- a/include/llvm/ExecutionEngine/Orc/RawByteChannel.h +++ b/include/llvm/ExecutionEngine/Orc/RawByteChannel.h @@ -48,9 +48,7 @@ public: template <typename FunctionIdT, typename SequenceIdT> Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { writeLock.lock(); - if (auto Err = serializeSeq(*this, FnId, SeqNo)) - return Err; - return Error::success(); + return serializeSeq(*this, FnId, SeqNo); } /// Notify the channel that we're ending a message send. |