summaryrefslogtreecommitdiff
path: root/include/llvm/ExecutionEngine/Orc
diff options
context:
space:
mode:
authorDimitry Andric <dim@FreeBSD.org>2017-01-09 21:23:09 +0000
committerDimitry Andric <dim@FreeBSD.org>2017-01-09 21:23:09 +0000
commit909545a822eef491158f831688066f0ec2866938 (patch)
tree5b0bf0e81294007a9b462b21031b3df272c655c3 /include/llvm/ExecutionEngine/Orc
parent7e7b6700743285c0af506ac6299ddf82ebd434b9 (diff)
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc')
-rw-r--r--include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h4
-rw-r--r--include/llvm/ExecutionEngine/Orc/RPCUtils.h246
-rw-r--r--include/llvm/ExecutionEngine/Orc/RawByteChannel.h4
3 files changed, 181 insertions, 73 deletions
diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h
index ab2b0fad89fd6..3086ef0cdf803 100644
--- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h
+++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h
@@ -83,7 +83,7 @@ public:
namespace remote {
class OrcRemoteTargetRPCAPI
- : public rpc::SingleThreadedRPC<rpc::RawByteChannel> {
+ : public rpc::SingleThreadedRPCEndpoint<rpc::RawByteChannel> {
protected:
class ResourceIdMgr {
public:
@@ -108,7 +108,7 @@ protected:
public:
// FIXME: Remove constructors once MSVC supports synthesizing move-ops.
OrcRemoteTargetRPCAPI(rpc::RawByteChannel &C)
- : rpc::SingleThreadedRPC<rpc::RawByteChannel>(C, true) {}
+ : rpc::SingleThreadedRPCEndpoint<rpc::RawByteChannel>(C, true) {}
class CallIntVoid
: public rpc::Function<CallIntVoid, int32_t(JITTargetAddress Addr)> {
diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h
index f51fbe153a41c..37e2e66e5af48 100644
--- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h
+++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h
@@ -702,7 +702,7 @@ public:
/// sync.
template <typename ImplT, typename ChannelT, typename FunctionIdT,
typename SequenceNumberT>
-class RPCBase {
+class RPCEndpointBase {
protected:
class OrcRPCInvalid : public Function<OrcRPCInvalid, void()> {
public:
@@ -747,7 +747,7 @@ protected:
public:
/// Construct an RPC instance on a channel.
- RPCBase(ChannelT &C, bool LazyAutoNegotiation)
+ RPCEndpointBase(ChannelT &C, bool LazyAutoNegotiation)
: C(C), LazyAutoNegotiation(LazyAutoNegotiation) {
// Hold ResponseId in a special variable, since we expect Response to be
// called relatively frequently, and want to avoid the map lookup.
@@ -788,15 +788,21 @@ public:
return FnIdOrErr.takeError();
}
- // Allocate a sequence number.
- auto SeqNo = SequenceNumberMgr.getSequenceNumber();
- assert(!PendingResponses.count(SeqNo) &&
- "Sequence number already allocated");
+ SequenceNumberT SeqNo; // initialized in locked scope below.
+ {
+ // Lock the pending responses map and sequence number manager.
+ std::lock_guard<std::mutex> Lock(ResponsesMutex);
+
+ // Allocate a sequence number.
+ SeqNo = SequenceNumberMgr.getSequenceNumber();
+ assert(!PendingResponses.count(SeqNo) &&
+ "Sequence number already allocated");
- // Install the user handler.
- PendingResponses[SeqNo] =
+ // Install the user handler.
+ PendingResponses[SeqNo] =
detail::createResponseHandler<ChannelT, typename Func::ReturnType>(
std::move(Handler));
+ }
// Open the function call message.
if (auto Err = C.startSendMessage(FnId, SeqNo)) {
@@ -863,11 +869,33 @@ public:
return detail::ReadArgs<ArgTs...>(Args...);
}
+ /// Abandon all outstanding result handlers.
+ ///
+ /// This will call all currently registered result handlers to receive an
+ /// "abandoned" error as their argument. This is used internally by the RPC
+ /// in error situations, but can also be called directly by clients who are
+ /// disconnecting from the remote and don't or can't expect responses to their
+ /// outstanding calls. (Especially for outstanding blocking calls, calling
+ /// this function may be necessary to avoid dead threads).
+ void abandonPendingResponses() {
+ // Lock the pending responses map and sequence number manager.
+ std::lock_guard<std::mutex> Lock(ResponsesMutex);
+
+ for (auto &KV : PendingResponses)
+ KV.second->abandon();
+ PendingResponses.clear();
+ SequenceNumberMgr.reset();
+ }
+
protected:
// The LaunchPolicy type allows a launch policy to be specified when adding
// a function handler. See addHandlerImpl.
using LaunchPolicy = std::function<Error(std::function<Error()>)>;
+ FunctionIdT getInvalidFunctionId() const {
+ return FnIdAllocator.getInvalidId();
+ }
+
/// Add the given handler to the handler map and make it available for
/// autonegotiation and execution.
template <typename Func, typename HandlerT>
@@ -884,28 +912,32 @@ protected:
wrapHandler<Func>(std::move(Handler), std::move(Launch));
}
- // Abandon all outstanding results.
- void abandonPendingResponses() {
- for (auto &KV : PendingResponses)
- KV.second->abandon();
- PendingResponses.clear();
- SequenceNumberMgr.reset();
- }
-
Error handleResponse(SequenceNumberT SeqNo) {
- auto I = PendingResponses.find(SeqNo);
- if (I == PendingResponses.end()) {
- abandonPendingResponses();
- return orcError(OrcErrorCode::UnexpectedRPCResponse);
+ using Handler = typename decltype(PendingResponses)::mapped_type;
+ Handler PRHandler;
+
+ {
+ // Lock the pending responses map and sequence number manager.
+ std::unique_lock<std::mutex> Lock(ResponsesMutex);
+ auto I = PendingResponses.find(SeqNo);
+
+ if (I != PendingResponses.end()) {
+ PRHandler = std::move(I->second);
+ PendingResponses.erase(I);
+ SequenceNumberMgr.releaseSequenceNumber(SeqNo);
+ } else {
+ // Unlock the pending results map to prevent recursive lock.
+ Lock.unlock();
+ abandonPendingResponses();
+ return orcError(OrcErrorCode::UnexpectedRPCResponse);
+ }
}
- auto PRHandler = std::move(I->second);
- PendingResponses.erase(I);
- SequenceNumberMgr.releaseSequenceNumber(SeqNo);
+ assert(PRHandler &&
+ "If we didn't find a response handler we should have bailed out");
if (auto Err = PRHandler->handleResponse(C)) {
abandonPendingResponses();
- SequenceNumberMgr.reset();
return Err;
}
@@ -915,7 +947,7 @@ protected:
FunctionIdT handleNegotiate(const std::string &Name) {
auto I = LocalFunctionIds.find(Name);
if (I == LocalFunctionIds.end())
- return FnIdAllocator.getInvalidId();
+ return getInvalidFunctionId();
return I->second;
}
@@ -938,7 +970,7 @@ protected:
// If autonegotiation indicates that the remote end doesn't support this
// function, return an unknown function error.
- if (RemoteId == FnIdAllocator.getInvalidId())
+ if (RemoteId == getInvalidFunctionId())
return orcError(OrcErrorCode::UnknownRPCFunction);
// Autonegotiation succeeded and returned a valid id. Update the map and
@@ -1012,6 +1044,7 @@ protected:
std::map<FunctionIdT, WrappedHandlerFn> Handlers;
+ std::mutex ResponsesMutex;
detail::SequenceNumberManager<SequenceNumberT> SequenceNumberMgr;
std::map<SequenceNumberT, std::unique_ptr<detail::ResponseHandler<ChannelT>>>
PendingResponses;
@@ -1021,17 +1054,18 @@ protected:
template <typename ChannelT, typename FunctionIdT = uint32_t,
typename SequenceNumberT = uint32_t>
-class MultiThreadedRPC
- : public detail::RPCBase<
- MultiThreadedRPC<ChannelT, FunctionIdT, SequenceNumberT>, ChannelT,
- FunctionIdT, SequenceNumberT> {
+class MultiThreadedRPCEndpoint
+ : public detail::RPCEndpointBase<
+ MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
+ ChannelT, FunctionIdT, SequenceNumberT> {
private:
using BaseClass =
- detail::RPCBase<MultiThreadedRPC<ChannelT, FunctionIdT, SequenceNumberT>,
- ChannelT, FunctionIdT, SequenceNumberT>;
+ detail::RPCEndpointBase<
+ MultiThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
+ ChannelT, FunctionIdT, SequenceNumberT>;
public:
- MultiThreadedRPC(ChannelT &C, bool LazyAutoNegotiation)
+ MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
: BaseClass(C, LazyAutoNegotiation) {}
/// The LaunchPolicy type allows a launch policy to be specified when adding
@@ -1061,30 +1095,41 @@ public:
std::move(Launch));
}
+ /// Add a class-method as a handler.
+ template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
+ void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...),
+ LaunchPolicy Launch = LaunchPolicy()) {
+ addHandler<Func>(
+ detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method),
+ Launch);
+ }
+
/// Negotiate a function id for Func with the other end of the channel.
- template <typename Func> Error negotiateFunction() {
+ template <typename Func> Error negotiateFunction(bool Retry = false) {
using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate;
+ // Check if we already have a function id...
+ auto I = this->RemoteFunctionIds.find(Func::getPrototype());
+ if (I != this->RemoteFunctionIds.end()) {
+ // If it's valid there's nothing left to do.
+ if (I->second != this->getInvalidFunctionId())
+ return Error::success();
+ // If it's invalid and we can't re-attempt negotiation, throw an error.
+ if (!Retry)
+ return orcError(OrcErrorCode::UnknownRPCFunction);
+ }
+
+ // We don't have a function id for Func yet, call the remote to try to
+ // negotiate one.
if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) {
this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
+ if (*RemoteIdOrErr == this->getInvalidFunctionId())
+ return orcError(OrcErrorCode::UnknownRPCFunction);
return Error::success();
} else
return RemoteIdOrErr.takeError();
}
- /// Convenience method for negotiating multiple functions at once.
- template <typename Func> Error negotiateFunctions() {
- return negotiateFunction<Func>();
- }
-
- /// Convenience method for negotiating multiple functions at once.
- template <typename Func1, typename Func2, typename... Funcs>
- Error negotiateFunctions() {
- if (auto Err = negotiateFunction<Func1>())
- return Err;
- return negotiateFunctions<Func2, Funcs...>();
- }
-
/// Return type for non-blocking call primitives.
template <typename Func>
using NonBlockingCallResult = typename detail::ResultTraits<
@@ -1169,19 +1214,20 @@ public:
template <typename ChannelT, typename FunctionIdT = uint32_t,
typename SequenceNumberT = uint32_t>
-class SingleThreadedRPC
- : public detail::RPCBase<
- SingleThreadedRPC<ChannelT, FunctionIdT, SequenceNumberT>, ChannelT,
- FunctionIdT, SequenceNumberT> {
+class SingleThreadedRPCEndpoint
+ : public detail::RPCEndpointBase<
+ SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
+ ChannelT, FunctionIdT, SequenceNumberT> {
private:
using BaseClass =
- detail::RPCBase<SingleThreadedRPC<ChannelT, FunctionIdT, SequenceNumberT>,
- ChannelT, FunctionIdT, SequenceNumberT>;
+ detail::RPCEndpointBase<
+ SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
+ ChannelT, FunctionIdT, SequenceNumberT>;
using LaunchPolicy = typename BaseClass::LaunchPolicy;
public:
- SingleThreadedRPC(ChannelT &C, bool LazyAutoNegotiation)
+ SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
: BaseClass(C, LazyAutoNegotiation) {}
template <typename Func, typename HandlerT>
@@ -1197,29 +1243,31 @@ public:
}
/// Negotiate a function id for Func with the other end of the channel.
- template <typename Func> Error negotiateFunction() {
+ template <typename Func> Error negotiateFunction(bool Retry = false) {
using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate;
+ // Check if we already have a function id...
+ auto I = this->RemoteFunctionIds.find(Func::getPrototype());
+ if (I != this->RemoteFunctionIds.end()) {
+ // If it's valid there's nothing left to do.
+ if (I->second != this->getInvalidFunctionId())
+ return Error::success();
+ // If it's invalid and we can't re-attempt negotiation, throw an error.
+ if (!Retry)
+ return orcError(OrcErrorCode::UnknownRPCFunction);
+ }
+
+ // We don't have a function id for Func yet, call the remote to try to
+ // negotiate one.
if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) {
this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
+ if (*RemoteIdOrErr == this->getInvalidFunctionId())
+ return orcError(OrcErrorCode::UnknownRPCFunction);
return Error::success();
} else
return RemoteIdOrErr.takeError();
}
- /// Convenience method for negotiating multiple functions at once.
- template <typename Func> Error negotiateFunctions() {
- return negotiateFunction<Func>();
- }
-
- /// Convenience method for negotiating multiple functions at once.
- template <typename Func1, typename Func2, typename... Funcs>
- Error negotiateFunctions() {
- if (auto Err = negotiateFunction<Func1>())
- return Err;
- return negotiateFunctions<Func2, Funcs...>();
- }
-
template <typename Func, typename... ArgTs,
typename AltRetT = typename Func::ReturnType>
typename detail::ResultTraits<AltRetT>::ErrorReturnType
@@ -1332,6 +1380,68 @@ private:
uint32_t NumOutstandingCalls;
};
+/// @brief Convenience class for grouping RPC Functions into APIs that can be
+/// negotiated as a block.
+///
+template <typename... Funcs>
+class APICalls {
+public:
+
+ /// @brief Test whether this API contains Function F.
+ template <typename F>
+ class Contains {
+ public:
+ static const bool value = false;
+ };
+
+ /// @brief Negotiate all functions in this API.
+ template <typename RPCEndpoint>
+ static Error negotiate(RPCEndpoint &R) {
+ return Error::success();
+ }
+};
+
+template <typename Func, typename... Funcs>
+class APICalls<Func, Funcs...> {
+public:
+
+ template <typename F>
+ class Contains {
+ public:
+ static const bool value = std::is_same<F, Func>::value |
+ APICalls<Funcs...>::template Contains<F>::value;
+ };
+
+ template <typename RPCEndpoint>
+ static Error negotiate(RPCEndpoint &R) {
+ if (auto Err = R.template negotiateFunction<Func>())
+ return Err;
+ return APICalls<Funcs...>::negotiate(R);
+ }
+
+};
+
+template <typename... InnerFuncs, typename... Funcs>
+class APICalls<APICalls<InnerFuncs...>, Funcs...> {
+public:
+
+ template <typename F>
+ class Contains {
+ public:
+ static const bool value =
+ APICalls<InnerFuncs...>::template Contains<F>::value |
+ APICalls<Funcs...>::template Contains<F>::value;
+ };
+
+ template <typename RPCEndpoint>
+ static Error negotiate(RPCEndpoint &R) {
+ if (auto Err = APICalls<InnerFuncs...>::negotiate(R))
+ return Err;
+ return APICalls<Funcs...>::negotiate(R);
+ }
+
+};
+
} // end namespace rpc
} // end namespace orc
} // end namespace llvm
diff --git a/include/llvm/ExecutionEngine/Orc/RawByteChannel.h b/include/llvm/ExecutionEngine/Orc/RawByteChannel.h
index 83a7b9a844f2a..3b6c84eb19650 100644
--- a/include/llvm/ExecutionEngine/Orc/RawByteChannel.h
+++ b/include/llvm/ExecutionEngine/Orc/RawByteChannel.h
@@ -48,9 +48,7 @@ public:
template <typename FunctionIdT, typename SequenceIdT>
Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
writeLock.lock();
- if (auto Err = serializeSeq(*this, FnId, SeqNo))
- return Err;
- return Error::success();
+ return serializeSeq(*this, FnId, SeqNo);
}
/// Notify the channel that we're ending a message send.