summaryrefslogtreecommitdiff
path: root/include/llvm/ExecutionEngine/Orc
diff options
context:
space:
mode:
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc')
-rw-r--r--include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h12
-rw-r--r--include/llvm/ExecutionEngine/Orc/OrcError.h6
-rw-r--r--include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h1
-rw-r--r--include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h23
-rw-r--r--include/llvm/ExecutionEngine/Orc/RPCSerialization.h243
-rw-r--r--include/llvm/ExecutionEngine/Orc/RPCUtils.h787
-rw-r--r--include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h (renamed from include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h)18
-rw-r--r--include/llvm/ExecutionEngine/Orc/RawByteChannel.h34
8 files changed, 833 insertions, 291 deletions
diff --git a/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h b/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h
index aa096478cd9e7..7e7f7358938a4 100644
--- a/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h
+++ b/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h
@@ -376,7 +376,7 @@ private:
// Initializers may refer to functions declared (but not defined) in this
// module. Build a materializer to clone decls on demand.
auto Materializer = createLambdaMaterializer(
- [this, &LD, &GVsM](Value *V) -> Value* {
+ [&LD, &GVsM](Value *V) -> Value* {
if (auto *F = dyn_cast<Function>(V)) {
// Decls in the original module just get cloned.
if (F->isDeclaration())
@@ -419,7 +419,7 @@ private:
// Build a resolver for the globals module and add it to the base layer.
auto GVsResolver = createLambdaResolver(
- [this, &LD, LMId](const std::string &Name) {
+ [this, &LD](const std::string &Name) {
if (auto Sym = LD.StubsMgr->findStub(Name, false))
return Sym;
if (auto Sym = LD.findSymbol(BaseLayer, Name, false))
@@ -499,8 +499,8 @@ private:
M->setDataLayout(SrcM.getDataLayout());
ValueToValueMapTy VMap;
- auto Materializer = createLambdaMaterializer([this, &LD, &LMId, &M,
- &VMap](Value *V) -> Value * {
+ auto Materializer = createLambdaMaterializer([&LD, &LMId,
+ &M](Value *V) -> Value * {
if (auto *GV = dyn_cast<GlobalVariable>(V))
return cloneGlobalVariableDecl(*M, *GV);
@@ -546,12 +546,12 @@ private:
// Create memory manager and symbol resolver.
auto Resolver = createLambdaResolver(
- [this, &LD, LMId](const std::string &Name) {
+ [this, &LD](const std::string &Name) {
if (auto Sym = LD.findSymbol(BaseLayer, Name, false))
return Sym;
return LD.ExternalSymbolResolver->findSymbolInLogicalDylib(Name);
},
- [this, &LD](const std::string &Name) {
+ [&LD](const std::string &Name) {
return LD.ExternalSymbolResolver->findSymbol(Name);
});
diff --git a/include/llvm/ExecutionEngine/Orc/OrcError.h b/include/llvm/ExecutionEngine/Orc/OrcError.h
index b74988cce2fb0..cbb40fad02230 100644
--- a/include/llvm/ExecutionEngine/Orc/OrcError.h
+++ b/include/llvm/ExecutionEngine/Orc/OrcError.h
@@ -27,13 +27,15 @@ enum class OrcErrorCode : int {
RemoteMProtectAddrUnrecognized,
RemoteIndirectStubsOwnerDoesNotExist,
RemoteIndirectStubsOwnerIdAlreadyInUse,
+ RPCConnectionClosed,
+ RPCCouldNotNegotiateFunction,
RPCResponseAbandoned,
UnexpectedRPCCall,
UnexpectedRPCResponse,
- UnknownRPCFunction
+ UnknownErrorCodeFromRemote
};
-Error orcError(OrcErrorCode ErrCode);
+std::error_code orcError(OrcErrorCode ErrCode);
} // End namespace orc.
} // End namespace llvm.
diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h
index 8647db56cd2f5..02f59d6a831a8 100644
--- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h
+++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h
@@ -18,6 +18,7 @@
#include "IndirectionUtils.h"
#include "OrcRemoteTargetRPCAPI.h"
+#include "llvm/ExecutionEngine/RuntimeDyld.h"
#include <system_error>
#define DEBUG_TYPE "orc-remote"
diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h
index 506330fe3a5e9..a61ff102be0b0 100644
--- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h
+++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h
@@ -132,7 +132,7 @@ private:
Error setProtections(void *block, unsigned Flags) {
auto I = Allocs.find(block);
if (I == Allocs.end())
- return orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized);
+ return errorCodeToError(orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized));
return errorCodeToError(
sys::Memory::protectMappedMemory(I->second, Flags));
}
@@ -198,7 +198,8 @@ private:
Error handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) {
auto I = Allocators.find(Id);
if (I != Allocators.end())
- return orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse);
+ return errorCodeToError(
+ orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse));
DEBUG(dbgs() << " Created allocator " << Id << "\n");
Allocators[Id] = Allocator();
return Error::success();
@@ -207,7 +208,8 @@ private:
Error handleCreateIndirectStubsOwner(ResourceIdMgr::ResourceId Id) {
auto I = IndirectStubsOwners.find(Id);
if (I != IndirectStubsOwners.end())
- return orcError(OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse);
+ return errorCodeToError(
+ orcError(OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse));
DEBUG(dbgs() << " Create indirect stubs owner " << Id << "\n");
IndirectStubsOwners[Id] = ISBlockOwnerList();
return Error::success();
@@ -224,7 +226,8 @@ private:
Error handleDestroyRemoteAllocator(ResourceIdMgr::ResourceId Id) {
auto I = Allocators.find(Id);
if (I == Allocators.end())
- return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
+ return errorCodeToError(
+ orcError(OrcErrorCode::RemoteAllocatorDoesNotExist));
Allocators.erase(I);
DEBUG(dbgs() << " Destroyed allocator " << Id << "\n");
return Error::success();
@@ -233,7 +236,8 @@ private:
Error handleDestroyIndirectStubsOwner(ResourceIdMgr::ResourceId Id) {
auto I = IndirectStubsOwners.find(Id);
if (I == IndirectStubsOwners.end())
- return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
+ return errorCodeToError(
+ orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist));
IndirectStubsOwners.erase(I);
return Error::success();
}
@@ -246,7 +250,8 @@ private:
auto StubOwnerItr = IndirectStubsOwners.find(Id);
if (StubOwnerItr == IndirectStubsOwners.end())
- return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
+ return errorCodeToError(
+ orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist));
typename TargetT::IndirectStubsInfo IS;
if (auto Err =
@@ -361,7 +366,8 @@ private:
uint64_t Size, uint32_t Align) {
auto I = Allocators.find(Id);
if (I == Allocators.end())
- return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
+ return errorCodeToError(
+ orcError(OrcErrorCode::RemoteAllocatorDoesNotExist));
auto &Allocator = I->second;
void *LocalAllocAddr = nullptr;
if (auto Err = Allocator.allocate(LocalAllocAddr, Size, Align))
@@ -380,7 +386,8 @@ private:
JITTargetAddress Addr, uint32_t Flags) {
auto I = Allocators.find(Id);
if (I == Allocators.end())
- return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
+ return errorCodeToError(
+ orcError(OrcErrorCode::RemoteAllocatorDoesNotExist));
auto &Allocator = I->second;
void *LocalAddr = reinterpret_cast<void *>(static_cast<uintptr_t>(Addr));
DEBUG(dbgs() << " Allocator " << Id << " set permissions on " << LocalAddr
diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h
index 359a9d81b22b5..84a037b2f998b 100644
--- a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h
+++ b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h
@@ -12,6 +12,7 @@
#include "OrcError.h"
#include "llvm/Support/thread.h"
+#include <map>
#include <mutex>
#include <sstream>
@@ -114,6 +115,35 @@ public:
static const char* getName() { return "std::string"; }
};
+template <>
+class RPCTypeName<Error> {
+public:
+ static const char* getName() { return "Error"; }
+};
+
+template <typename T>
+class RPCTypeName<Expected<T>> {
+public:
+ static const char* getName() {
+ std::lock_guard<std::mutex> Lock(NameMutex);
+ if (Name.empty())
+ raw_string_ostream(Name) << "Expected<"
+ << RPCTypeNameSequence<T>()
+ << ">";
+ return Name.data();
+ }
+
+private:
+ static std::mutex NameMutex;
+ static std::string Name;
+};
+
+template <typename T>
+std::mutex RPCTypeName<Expected<T>>::NameMutex;
+
+template <typename T>
+std::string RPCTypeName<Expected<T>>::Name;
+
template <typename T1, typename T2>
class RPCTypeName<std::pair<T1, T2>> {
public:
@@ -243,8 +273,10 @@ class SequenceSerialization<ChannelT, ArgT> {
public:
template <typename CArgT>
- static Error serialize(ChannelT &C, const CArgT &CArg) {
- return SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg);
+ static Error serialize(ChannelT &C, CArgT &&CArg) {
+ return SerializationTraits<ChannelT, ArgT,
+ typename std::decay<CArgT>::type>::
+ serialize(C, std::forward<CArgT>(CArg));
}
template <typename CArgT>
@@ -258,19 +290,21 @@ class SequenceSerialization<ChannelT, ArgT, ArgTs...> {
public:
template <typename CArgT, typename... CArgTs>
- static Error serialize(ChannelT &C, const CArgT &CArg,
- const CArgTs&... CArgs) {
+ static Error serialize(ChannelT &C, CArgT &&CArg,
+ CArgTs &&... CArgs) {
if (auto Err =
- SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg))
+ SerializationTraits<ChannelT, ArgT, typename std::decay<CArgT>::type>::
+ serialize(C, std::forward<CArgT>(CArg)))
return Err;
if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C))
return Err;
- return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...);
+ return SequenceSerialization<ChannelT, ArgTs...>::
+ serialize(C, std::forward<CArgTs>(CArgs)...);
}
template <typename CArgT, typename... CArgTs>
static Error deserialize(ChannelT &C, CArgT &CArg,
- CArgTs&... CArgs) {
+ CArgTs &... CArgs) {
if (auto Err =
SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg))
return Err;
@@ -281,8 +315,9 @@ public:
};
template <typename ChannelT, typename... ArgTs>
-Error serializeSeq(ChannelT &C, const ArgTs &... Args) {
- return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, Args...);
+Error serializeSeq(ChannelT &C, ArgTs &&... Args) {
+ return SequenceSerialization<ChannelT, typename std::decay<ArgTs>::type...>::
+ serialize(C, std::forward<ArgTs>(Args)...);
}
template <typename ChannelT, typename... ArgTs>
@@ -290,6 +325,196 @@ Error deserializeSeq(ChannelT &C, ArgTs &... Args) {
return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...);
}
+template <typename ChannelT>
+class SerializationTraits<ChannelT, Error> {
+public:
+
+ using WrappedErrorSerializer =
+ std::function<Error(ChannelT &C, const ErrorInfoBase&)>;
+
+ using WrappedErrorDeserializer =
+ std::function<Error(ChannelT &C, Error &Err)>;
+
+ template <typename ErrorInfoT, typename SerializeFtor,
+ typename DeserializeFtor>
+ static void registerErrorType(std::string Name, SerializeFtor Serialize,
+ DeserializeFtor Deserialize) {
+ assert(!Name.empty() &&
+ "The empty string is reserved for the Success value");
+
+ const std::string *KeyName = nullptr;
+ {
+ // We're abusing the stability of std::map here: We take a reference to the
+ // key of the deserializers map to save us from duplicating the string in
+ // the serializer. This should be changed to use a stringpool if we switch
+ // to a map type that may move keys in memory.
+ std::lock_guard<std::mutex> Lock(DeserializersMutex);
+ auto I =
+ Deserializers.insert(Deserializers.begin(),
+ std::make_pair(std::move(Name),
+ std::move(Deserialize)));
+ KeyName = &I->first;
+ }
+
+ {
+ assert(KeyName != nullptr && "No keyname pointer");
+ std::lock_guard<std::mutex> Lock(SerializersMutex);
+ // FIXME: Move capture Serialize once we have C++14.
+ Serializers[ErrorInfoT::classID()] =
+ [KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error {
+ assert(EIB.dynamicClassID() == ErrorInfoT::classID() &&
+ "Serializer called for wrong error type");
+ if (auto Err = serializeSeq(C, *KeyName))
+ return Err;
+ return Serialize(C, static_cast<const ErrorInfoT&>(EIB));
+ };
+ }
+ }
+
+ static Error serialize(ChannelT &C, Error &&Err) {
+ std::lock_guard<std::mutex> Lock(SerializersMutex);
+ if (!Err)
+ return serializeSeq(C, std::string());
+
+ return handleErrors(std::move(Err),
+ [&C](const ErrorInfoBase &EIB) {
+ auto SI = Serializers.find(EIB.dynamicClassID());
+ if (SI == Serializers.end())
+ return serializeAsStringError(C, EIB);
+ return (SI->second)(C, EIB);
+ });
+ }
+
+ static Error deserialize(ChannelT &C, Error &Err) {
+ std::lock_guard<std::mutex> Lock(DeserializersMutex);
+
+ std::string Key;
+ if (auto Err = deserializeSeq(C, Key))
+ return Err;
+
+ if (Key.empty()) {
+ ErrorAsOutParameter EAO(&Err);
+ Err = Error::success();
+ return Error::success();
+ }
+
+ auto DI = Deserializers.find(Key);
+ assert(DI != Deserializers.end() && "No deserializer for error type");
+ return (DI->second)(C, Err);
+ }
+
+private:
+
+ static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) {
+ assert(EIB.dynamicClassID() != StringError::classID() &&
+ "StringError serialization not registered");
+ std::string ErrMsg;
+ {
+ raw_string_ostream ErrMsgStream(ErrMsg);
+ EIB.log(ErrMsgStream);
+ }
+ return serialize(C, make_error<StringError>(std::move(ErrMsg),
+ inconvertibleErrorCode()));
+ }
+
+ static std::mutex SerializersMutex;
+ static std::mutex DeserializersMutex;
+ static std::map<const void*, WrappedErrorSerializer> Serializers;
+ static std::map<std::string, WrappedErrorDeserializer> Deserializers;
+};
+
+template <typename ChannelT>
+std::mutex SerializationTraits<ChannelT, Error>::SerializersMutex;
+
+template <typename ChannelT>
+std::mutex SerializationTraits<ChannelT, Error>::DeserializersMutex;
+
+template <typename ChannelT>
+std::map<const void*,
+ typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer>
+SerializationTraits<ChannelT, Error>::Serializers;
+
+template <typename ChannelT>
+std::map<std::string,
+ typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer>
+SerializationTraits<ChannelT, Error>::Deserializers;
+
+template <typename ChannelT>
+void registerStringError() {
+ static bool AlreadyRegistered = false;
+ if (!AlreadyRegistered) {
+ SerializationTraits<ChannelT, Error>::
+ template registerErrorType<StringError>(
+ "StringError",
+ [](ChannelT &C, const StringError &SE) {
+ return serializeSeq(C, SE.getMessage());
+ },
+ [](ChannelT &C, Error &Err) {
+ ErrorAsOutParameter EAO(&Err);
+ std::string Msg;
+ if (auto E2 = deserializeSeq(C, Msg))
+ return E2;
+ Err =
+ make_error<StringError>(std::move(Msg),
+ orcError(
+ OrcErrorCode::UnknownErrorCodeFromRemote));
+ return Error::success();
+ });
+ AlreadyRegistered = true;
+ }
+}
+
+/// SerializationTraits for Expected<T1> from an Expected<T2>.
+template <typename ChannelT, typename T1, typename T2>
+class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> {
+public:
+
+ static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) {
+ if (ValOrErr) {
+ if (auto Err = serializeSeq(C, true))
+ return Err;
+ return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr);
+ }
+ if (auto Err = serializeSeq(C, false))
+ return Err;
+ return serializeSeq(C, ValOrErr.takeError());
+ }
+
+ static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) {
+ ExpectedAsOutParameter<T2> EAO(&ValOrErr);
+ bool HasValue;
+ if (auto Err = deserializeSeq(C, HasValue))
+ return Err;
+ if (HasValue)
+ return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr);
+ Error Err = Error::success();
+ if (auto E2 = deserializeSeq(C, Err))
+ return E2;
+ ValOrErr = std::move(Err);
+ return Error::success();
+ }
+};
+
+/// SerializationTraits for Expected<T1> from a T2.
+template <typename ChannelT, typename T1, typename T2>
+class SerializationTraits<ChannelT, Expected<T1>, T2> {
+public:
+
+ static Error serialize(ChannelT &C, T2 &&Val) {
+ return serializeSeq(C, Expected<T2>(std::forward<T2>(Val)));
+ }
+};
+
+/// SerializationTraits for Expected<T1> from an Error.
+template <typename ChannelT, typename T>
+class SerializationTraits<ChannelT, Expected<T>, Error> {
+public:
+
+ static Error serialize(ChannelT &C, Error &&Err) {
+ return serializeSeq(C, Expected<T>(std::move(Err)));
+ }
+};
+
/// SerializationTraits default specialization for std::pair.
template <typename ChannelT, typename T1, typename T2>
class SerializationTraits<ChannelT, std::pair<T1, T2>> {
diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h
index 37e2e66e5af48..6212f64ff3195 100644
--- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h
+++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h
@@ -26,27 +26,115 @@
#include "llvm/ExecutionEngine/Orc/OrcError.h"
#include "llvm/ExecutionEngine/Orc/RPCSerialization.h"
-#ifdef _MSC_VER
-// concrt.h depends on eh.h for __uncaught_exception declaration
-// even if we disable exceptions.
-#include <eh.h>
-
-// Disable warnings from ppltasks.h transitively included by <future>.
-#pragma warning(push)
-#pragma warning(disable : 4530)
-#pragma warning(disable : 4062)
-#endif
-
#include <future>
-#ifdef _MSC_VER
-#pragma warning(pop)
-#endif
-
namespace llvm {
namespace orc {
namespace rpc {
+/// Base class of all fatal RPC errors (those that necessarily result in the
+/// termination of the RPC session).
+class RPCFatalError : public ErrorInfo<RPCFatalError> {
+public:
+ static char ID;
+};
+
+/// RPCConnectionClosed is returned from RPC operations if the RPC connection
+/// has already been closed due to either an error or graceful disconnection.
+class ConnectionClosed : public ErrorInfo<ConnectionClosed> {
+public:
+ static char ID;
+ std::error_code convertToErrorCode() const override;
+ void log(raw_ostream &OS) const override;
+};
+
+/// BadFunctionCall is returned from handleOne when the remote makes a call with
+/// an unrecognized function id.
+///
+/// This error is fatal because Orc RPC needs to know how to parse a function
+/// call to know where the next call starts, and if it doesn't recognize the
+/// function id it cannot parse the call.
+template <typename FnIdT, typename SeqNoT>
+class BadFunctionCall
+ : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> {
+public:
+ static char ID;
+
+ BadFunctionCall(FnIdT FnId, SeqNoT SeqNo)
+ : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {}
+
+ std::error_code convertToErrorCode() const override {
+ return orcError(OrcErrorCode::UnexpectedRPCCall);
+ }
+
+ void log(raw_ostream &OS) const override {
+ OS << "Call to invalid RPC function id '" << FnId << "' with "
+ "sequence number " << SeqNo;
+ }
+
+private:
+ FnIdT FnId;
+ SeqNoT SeqNo;
+};
+
+template <typename FnIdT, typename SeqNoT>
+char BadFunctionCall<FnIdT, SeqNoT>::ID = 0;
+
+/// InvalidSequenceNumberForResponse is returned from handleOne when a response
+/// call arrives with a sequence number that doesn't correspond to any in-flight
+/// function call.
+///
+/// This error is fatal because Orc RPC needs to know how to parse the rest of
+/// the response call to know where the next call starts, and if it doesn't have
+/// a result parser for this sequence number it can't do that.
+template <typename SeqNoT>
+class InvalidSequenceNumberForResponse
+ : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> {
+public:
+ static char ID;
+
+ InvalidSequenceNumberForResponse(SeqNoT SeqNo)
+ : SeqNo(std::move(SeqNo)) {}
+
+ std::error_code convertToErrorCode() const override {
+ return orcError(OrcErrorCode::UnexpectedRPCCall);
+ };
+
+ void log(raw_ostream &OS) const override {
+ OS << "Response has unknown sequence number " << SeqNo;
+ }
+private:
+ SeqNoT SeqNo;
+};
+
+template <typename SeqNoT>
+char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0;
+
+/// This non-fatal error will be passed to asynchronous result handlers in place
+/// of a result if the connection goes down before a result returns, or if the
+/// function to be called cannot be negotiated with the remote.
+class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> {
+public:
+ static char ID;
+
+ std::error_code convertToErrorCode() const override;
+ void log(raw_ostream &OS) const override;
+};
+
+/// This error is returned if the remote does not have a handler installed for
+/// the given RPC function.
+class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> {
+public:
+ static char ID;
+
+ CouldNotNegotiate(std::string Signature);
+ std::error_code convertToErrorCode() const override;
+ void log(raw_ostream &OS) const override;
+ const std::string &getSignature() const { return Signature; }
+private:
+ std::string Signature;
+};
+
template <typename DerivedFunc, typename FnT> class Function;
// RPC Function class.
@@ -82,16 +170,6 @@ std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex;
template <typename DerivedFunc, typename RetT, typename... ArgTs>
std::string Function<DerivedFunc, RetT(ArgTs...)>::Name;
-/// Provides a typedef for a tuple containing the decayed argument types.
-template <typename T> class FunctionArgsTuple;
-
-template <typename RetT, typename... ArgTs>
-class FunctionArgsTuple<RetT(ArgTs...)> {
-public:
- using Type = std::tuple<typename std::decay<
- typename std::remove_reference<ArgTs>::type>::type...>;
-};
-
/// Allocates RPC function ids during autonegotiation.
/// Specializations of this class must provide four members:
///
@@ -196,6 +274,16 @@ public:
#endif // _MSC_VER
+/// Provides a typedef for a tuple containing the decayed argument types.
+template <typename T> class FunctionArgsTuple;
+
+template <typename RetT, typename... ArgTs>
+class FunctionArgsTuple<RetT(ArgTs...)> {
+public:
+ using Type = std::tuple<typename std::decay<
+ typename std::remove_reference<ArgTs>::type>::type...>;
+};
+
// ResultTraits provides typedefs and utilities specific to the return type
// of functions.
template <typename RetT> class ResultTraits {
@@ -274,43 +362,132 @@ template <> class ResultTraits<Error> : public ResultTraits<void> {};
template <typename RetT>
class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {};
+// Determines whether an RPC function's defined error return type supports
+// error return value.
+template <typename T>
+class SupportsErrorReturn {
+public:
+ static const bool value = false;
+};
+
+template <>
+class SupportsErrorReturn<Error> {
+public:
+ static const bool value = true;
+};
+
+template <typename T>
+class SupportsErrorReturn<Expected<T>> {
+public:
+ static const bool value = true;
+};
+
+// RespondHelper packages return values based on whether or not the declared
+// RPC function return type supports error returns.
+template <bool FuncSupportsErrorReturn>
+class RespondHelper;
+
+// RespondHelper specialization for functions that support error returns.
+template <>
+class RespondHelper<true> {
+public:
+
+ // Send Expected<T>.
+ template <typename WireRetT, typename HandlerRetT, typename ChannelT,
+ typename FunctionIdT, typename SequenceNumberT>
+ static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
+ SequenceNumberT SeqNo,
+ Expected<HandlerRetT> ResultOrErr) {
+ if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>())
+ return ResultOrErr.takeError();
+
+ // Open the response message.
+ if (auto Err = C.startSendMessage(ResponseId, SeqNo))
+ return Err;
+
+ // Serialize the result.
+ if (auto Err =
+ SerializationTraits<ChannelT, WireRetT,
+ Expected<HandlerRetT>>::serialize(
+ C, std::move(ResultOrErr)))
+ return Err;
+
+ // Close the response message.
+ return C.endSendMessage();
+ }
+
+ template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
+ static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
+ SequenceNumberT SeqNo, Error Err) {
+ if (Err && Err.isA<RPCFatalError>())
+ return Err;
+ if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
+ return Err2;
+ if (auto Err2 = serializeSeq(C, std::move(Err)))
+ return Err2;
+ return C.endSendMessage();
+ }
+
+};
+
+// RespondHelper specialization for functions that do not support error returns.
+template <>
+class RespondHelper<false> {
+public:
+
+ template <typename WireRetT, typename HandlerRetT, typename ChannelT,
+ typename FunctionIdT, typename SequenceNumberT>
+ static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
+ SequenceNumberT SeqNo,
+ Expected<HandlerRetT> ResultOrErr) {
+ if (auto Err = ResultOrErr.takeError())
+ return Err;
+
+ // Open the response message.
+ if (auto Err = C.startSendMessage(ResponseId, SeqNo))
+ return Err;
+
+ // Serialize the result.
+ if (auto Err =
+ SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
+ C, *ResultOrErr))
+ return Err;
+
+ // Close the response message.
+ return C.endSendMessage();
+ }
+
+ template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT>
+ static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId,
+ SequenceNumberT SeqNo, Error Err) {
+ if (Err)
+ return Err;
+ if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
+ return Err2;
+ return C.endSendMessage();
+ }
+
+};
+
+
// Send a response of the given wire return type (WireRetT) over the
// channel, with the given sequence number.
template <typename WireRetT, typename HandlerRetT, typename ChannelT,
typename FunctionIdT, typename SequenceNumberT>
-static Error respond(ChannelT &C, const FunctionIdT &ResponseId,
- SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
- // If this was an error bail out.
- // FIXME: Send an "error" message to the client if this is not a channel
- // failure?
- if (auto Err = ResultOrErr.takeError())
- return Err;
-
- // Open the response message.
- if (auto Err = C.startSendMessage(ResponseId, SeqNo))
- return Err;
-
- // Serialize the result.
- if (auto Err =
- SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize(
- C, *ResultOrErr))
- return Err;
-
- // Close the response message.
- return C.endSendMessage();
+Error respond(ChannelT &C, const FunctionIdT &ResponseId,
+ SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) {
+ return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
+ template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));
}
// Send an empty response message on the given channel to indicate that
// the handler ran.
template <typename WireRetT, typename ChannelT, typename FunctionIdT,
typename SequenceNumberT>
-static Error respond(ChannelT &C, const FunctionIdT &ResponseId,
- SequenceNumberT SeqNo, Error Err) {
- if (Err)
- return Err;
- if (auto Err2 = C.startSendMessage(ResponseId, SeqNo))
- return Err2;
- return C.endSendMessage();
+Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo,
+ Error Err) {
+ return RespondHelper<SupportsErrorReturn<WireRetT>::value>::
+ sendResult(C, ResponseId, SeqNo, std::move(Err));
}
// Converts a given type to the equivalent error return type.
@@ -339,6 +516,29 @@ public:
using Type = Error;
};
+// Traits class that strips the response function from the list of handler
+// arguments.
+template <typename FnT> class AsyncHandlerTraits;
+
+template <typename ResultT, typename... ArgTs>
+class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> {
+public:
+ using Type = Error(ArgTs...);
+ using ResultType = Expected<ResultT>;
+};
+
+template <typename... ArgTs>
+class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> {
+public:
+ using Type = Error(ArgTs...);
+ using ResultType = Error;
+};
+
+template <typename ResponseHandlerT, typename... ArgTs>
+class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> :
+ public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type,
+ ArgTs...)> {};
+
// This template class provides utilities related to RPC function handlers.
// The base case applies to non-function types (the template class is
// specialized for function types) and inherits from the appropriate
@@ -358,15 +558,20 @@ public:
// Return type of the handler.
using ReturnType = RetT;
- // A std::tuple wrapping the handler arguments.
- using ArgStorage = typename FunctionArgsTuple<RetT(ArgTs...)>::Type;
-
// Call the given handler with the given arguments.
- template <typename HandlerT>
+ template <typename HandlerT, typename... TArgTs>
static typename WrappedHandlerReturn<RetT>::Type
- unpackAndRun(HandlerT &Handler, ArgStorage &Args) {
+ unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {
return unpackAndRunHelper(Handler, Args,
- llvm::index_sequence_for<ArgTs...>());
+ llvm::index_sequence_for<TArgTs...>());
+ }
+
+ // Call the given handler with the given arguments.
+ template <typename HandlerT, typename ResponderT, typename... TArgTs>
+ static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder,
+ std::tuple<TArgTs...> &Args) {
+ return unpackAndRunAsyncHelper(Handler, Responder, Args,
+ llvm::index_sequence_for<TArgTs...>());
}
// Call the given handler with the given arguments.
@@ -379,11 +584,11 @@ public:
return Error::success();
}
- template <typename HandlerT>
+ template <typename HandlerT, typename... TArgTs>
static typename std::enable_if<
!std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,
typename HandlerTraits<HandlerT>::ReturnType>::type
- run(HandlerT &Handler, ArgTs... Args) {
+ run(HandlerT &Handler, TArgTs... Args) {
return Handler(std::move(Args)...);
}
@@ -408,15 +613,31 @@ private:
C, std::get<Indexes>(Args)...);
}
- template <typename HandlerT, size_t... Indexes>
+ template <typename HandlerT, typename ArgTuple, size_t... Indexes>
static typename WrappedHandlerReturn<
typename HandlerTraits<HandlerT>::ReturnType>::Type
- unpackAndRunHelper(HandlerT &Handler, ArgStorage &Args,
+ unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,
llvm::index_sequence<Indexes...>) {
return run(Handler, std::move(std::get<Indexes>(Args))...);
}
+
+
+ template <typename HandlerT, typename ResponderT, typename ArgTuple,
+ size_t... Indexes>
+ static typename WrappedHandlerReturn<
+ typename HandlerTraits<HandlerT>::ReturnType>::Type
+ unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder,
+ ArgTuple &Args,
+ llvm::index_sequence<Indexes...>) {
+ return run(Handler, Responder, std::move(std::get<Indexes>(Args))...);
+ }
};
+// Handler traits for free functions.
+template <typename RetT, typename... ArgTs>
+class HandlerTraits<RetT(*)(ArgTs...)>
+ : public HandlerTraits<RetT(ArgTs...)> {};
+
// Handler traits for class methods (especially call operators for lambdas).
template <typename Class, typename RetT, typename... ArgTs>
class HandlerTraits<RetT (Class::*)(ArgTs...)>
@@ -471,7 +692,7 @@ public:
// Create an error instance representing an abandoned response.
static Error createAbandonedResponseError() {
- return orcError(OrcErrorCode::RPCResponseAbandoned);
+ return make_error<ResponseAbandoned>();
}
};
@@ -493,7 +714,7 @@ public:
return Err;
if (auto Err = C.endReceiveMessage())
return Err;
- return Handler(Result);
+ return Handler(std::move(Result));
}
// Abandon this response by calling the handler with an 'abandoned response'
@@ -538,6 +759,72 @@ private:
HandlerT Handler;
};
+template <typename ChannelT, typename FuncRetT, typename HandlerT>
+class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT>
+ : public ResponseHandler<ChannelT> {
+public:
+ ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
+
+ // Handle the result by deserializing it from the channel then passing it
+ // to the user defined handler.
+ Error handleResponse(ChannelT &C) override {
+ using HandlerArgType = typename ResponseHandlerArg<
+ typename HandlerTraits<HandlerT>::Type>::ArgType;
+ HandlerArgType Result((typename HandlerArgType::value_type()));
+
+ if (auto Err =
+ SerializationTraits<ChannelT, Expected<FuncRetT>,
+ HandlerArgType>::deserialize(C, Result))
+ return Err;
+ if (auto Err = C.endReceiveMessage())
+ return Err;
+ return Handler(std::move(Result));
+ }
+
+ // Abandon this response by calling the handler with an 'abandoned response'
+ // error.
+ void abandon() override {
+ if (auto Err = Handler(this->createAbandonedResponseError())) {
+ // Handlers should not fail when passed an abandoned response error.
+ report_fatal_error(std::move(Err));
+ }
+ }
+
+private:
+ HandlerT Handler;
+};
+
+template <typename ChannelT, typename HandlerT>
+class ResponseHandlerImpl<ChannelT, Error, HandlerT>
+ : public ResponseHandler<ChannelT> {
+public:
+ ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {}
+
+ // Handle the result by deserializing it from the channel then passing it
+ // to the user defined handler.
+ Error handleResponse(ChannelT &C) override {
+ Error Result = Error::success();
+ if (auto Err =
+ SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result))
+ return Err;
+ if (auto Err = C.endReceiveMessage())
+ return Err;
+ return Handler(std::move(Result));
+ }
+
+ // Abandon this response by calling the handler with an 'abandoned response'
+ // error.
+ void abandon() override {
+ if (auto Err = Handler(this->createAbandonedResponseError())) {
+ // Handlers should not fail when passed an abandoned response error.
+ report_fatal_error(std::move(Err));
+ }
+ }
+
+private:
+ HandlerT Handler;
+};
+
// Create a ResponseHandler from a given user handler.
template <typename ChannelT, typename FuncRetT, typename HandlerT>
std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) {
@@ -758,8 +1045,13 @@ public:
auto NegotiateId = FnIdAllocator.getNegotiateId();
RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;
Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>(
- [this](const std::string &Name) { return handleNegotiate(Name); },
- LaunchPolicy());
+ [this](const std::string &Name) { return handleNegotiate(Name); });
+ }
+
+
+ /// Negotiate a function id for Func with the other end of the channel.
+ template <typename Func> Error negotiateFunction(bool Retry = false) {
+ return getRemoteFunctionId<Func>(true, Retry).takeError();
}
/// Append a call Func, does not call send on the channel.
@@ -777,14 +1069,12 @@ public:
// Look up the function ID.
FunctionIdT FnId;
- if (auto FnIdOrErr = getRemoteFunctionId<Func>())
+ if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))
FnId = *FnIdOrErr;
else {
- // This isn't a channel error so we don't want to abandon other pending
- // responses, but we still need to run the user handler with an error to
- // let them know the call failed.
- if (auto Err = Handler(orcError(OrcErrorCode::UnknownRPCFunction)))
- report_fatal_error(std::move(Err));
+ // Negotiation failed. Notify the handler then return the negotiate-failed
+ // error.
+ cantFail(Handler(make_error<ResponseAbandoned>()));
return FnIdOrErr.takeError();
}
@@ -807,20 +1097,20 @@ public:
// Open the function call message.
if (auto Err = C.startSendMessage(FnId, SeqNo)) {
abandonPendingResponses();
- return joinErrors(std::move(Err), C.endSendMessage());
+ return Err;
}
// Serialize the call arguments.
if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(
C, Args...)) {
abandonPendingResponses();
- return joinErrors(std::move(Err), C.endSendMessage());
+ return Err;
}
// Close the function call messagee.
if (auto Err = C.endSendMessage()) {
abandonPendingResponses();
- return std::move(Err);
+ return Err;
}
return Error::success();
@@ -839,8 +1129,10 @@ public:
Error handleOne() {
FunctionIdT FnId;
SequenceNumberT SeqNo;
- if (auto Err = C.startReceiveMessage(FnId, SeqNo))
+ if (auto Err = C.startReceiveMessage(FnId, SeqNo)) {
+ abandonPendingResponses();
return Err;
+ }
if (FnId == ResponseId)
return handleResponse(SeqNo);
auto I = Handlers.find(FnId);
@@ -848,7 +1140,8 @@ public:
return I->second(C, SeqNo);
// else: No handler found. Report error to client?
- return orcError(OrcErrorCode::UnexpectedRPCCall);
+ return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId,
+ SeqNo);
}
/// Helper for handling setter procedures - this method returns a functor that
@@ -887,10 +1180,25 @@ public:
SequenceNumberMgr.reset();
}
+ /// Remove the handler for the given function.
+ /// A handler must currently be registered for this function.
+ template <typename Func>
+ void removeHandler() {
+ auto IdItr = LocalFunctionIds.find(Func::getPrototype());
+ assert(IdItr != LocalFunctionIds.end() &&
+ "Function does not have a registered handler");
+ auto HandlerItr = Handlers.find(IdItr->second);
+ assert(HandlerItr != Handlers.end() &&
+ "Function does not have a registered handler");
+ Handlers.erase(HandlerItr);
+ }
+
+ /// Clear all handlers.
+ void clearHandlers() {
+ Handlers.clear();
+ }
+
protected:
- // The LaunchPolicy type allows a launch policy to be specified when adding
- // a function handler. See addHandlerImpl.
- using LaunchPolicy = std::function<Error(std::function<Error()>)>;
FunctionIdT getInvalidFunctionId() const {
return FnIdAllocator.getInvalidId();
@@ -899,7 +1207,7 @@ protected:
/// Add the given handler to the handler map and make it available for
/// autonegotiation and execution.
template <typename Func, typename HandlerT>
- void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) {
+ void addHandlerImpl(HandlerT Handler) {
static_assert(detail::RPCArgTypeCheck<
CanDeserializeCheck, typename Func::Type,
@@ -908,8 +1216,22 @@ protected:
FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
LocalFunctionIds[Func::getPrototype()] = NewFnId;
- Handlers[NewFnId] =
- wrapHandler<Func>(std::move(Handler), std::move(Launch));
+ Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler));
+ }
+
+ template <typename Func, typename HandlerT>
+ void addAsyncHandlerImpl(HandlerT Handler) {
+
+ static_assert(detail::RPCArgTypeCheck<
+ CanDeserializeCheck, typename Func::Type,
+ typename detail::AsyncHandlerTraits<
+ typename detail::HandlerTraits<HandlerT>::Type
+ >::Type>::value,
+ "");
+
+ FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();
+ LocalFunctionIds[Func::getPrototype()] = NewFnId;
+ Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));
}
Error handleResponse(SequenceNumberT SeqNo) {
@@ -929,7 +1251,8 @@ protected:
// Unlock the pending results map to prevent recursive lock.
Lock.unlock();
abandonPendingResponses();
- return orcError(OrcErrorCode::UnexpectedRPCResponse);
+ return make_error<
+ InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);
}
}
@@ -951,41 +1274,39 @@ protected:
return I->second;
}
- // Find the remote FunctionId for the given function, which must be in the
- // RemoteFunctionIds map.
- template <typename Func> Expected<FunctionIdT> getRemoteFunctionId() {
- // Try to find the id for the given function.
- auto I = RemoteFunctionIds.find(Func::getPrototype());
+ // Find the remote FunctionId for the given function.
+ template <typename Func>
+ Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap,
+ bool NegotiateIfInvalid) {
+ bool DoNegotiate;
- // If we have it in the map, return it.
- if (I != RemoteFunctionIds.end())
- return I->second;
+ // Check if we already have a function id...
+ auto I = RemoteFunctionIds.find(Func::getPrototype());
+ if (I != RemoteFunctionIds.end()) {
+ // If it's valid there's nothing left to do.
+ if (I->second != getInvalidFunctionId())
+ return I->second;
+ DoNegotiate = NegotiateIfInvalid;
+ } else
+ DoNegotiate = NegotiateIfNotInMap;
- // Otherwise, if we have auto-negotiation enabled, try to negotiate it.
- if (LazyAutoNegotiation) {
+ // We don't have a function id for Func yet, but we're allowed to try to
+ // negotiate one.
+ if (DoNegotiate) {
auto &Impl = static_cast<ImplT &>(*this);
if (auto RemoteIdOrErr =
- Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
- auto &RemoteId = *RemoteIdOrErr;
-
- // If autonegotiation indicates that the remote end doesn't support this
- // function, return an unknown function error.
- if (RemoteId == getInvalidFunctionId())
- return orcError(OrcErrorCode::UnknownRPCFunction);
-
- // Autonegotiation succeeded and returned a valid id. Update the map and
- // return the id.
- RemoteFunctionIds[Func::getPrototype()] = RemoteId;
- return RemoteId;
- } else {
- // Autonegotiation failed. Return the error.
+ Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) {
+ RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
+ if (*RemoteIdOrErr == getInvalidFunctionId())
+ return make_error<CouldNotNegotiate>(Func::getPrototype());
+ return *RemoteIdOrErr;
+ } else
return RemoteIdOrErr.takeError();
- }
}
- // No key was available in the map and autonegotiation wasn't enabled.
- // Return an unknown function error.
- return orcError(OrcErrorCode::UnknownRPCFunction);
+ // No key was available in the map and we weren't allowed to try to
+ // negotiate one, so return an unknown function error.
+ return make_error<CouldNotNegotiate>(Func::getPrototype());
}
using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>;
@@ -993,12 +1314,15 @@ protected:
// Wrap the given user handler in the necessary argument-deserialization code,
// result-serialization code, and call to the launch policy (if present).
template <typename Func, typename HandlerT>
- WrappedHandlerFn wrapHandler(HandlerT Handler, LaunchPolicy Launch) {
- return [this, Handler, Launch](ChannelT &Channel,
- SequenceNumberT SeqNo) mutable -> Error {
+ WrappedHandlerFn wrapHandler(HandlerT Handler) {
+ return [this, Handler](ChannelT &Channel,
+ SequenceNumberT SeqNo) mutable -> Error {
// Start by deserializing the arguments.
- auto Args = std::make_shared<
- typename detail::HandlerTraits<HandlerT>::ArgStorage>();
+ using ArgsTuple =
+ typename detail::FunctionArgsTuple<
+ typename detail::HandlerTraits<HandlerT>::Type>::Type;
+ auto Args = std::make_shared<ArgsTuple>();
+
if (auto Err =
detail::HandlerTraits<typename Func::Type>::deserializeArgs(
Channel, *Args))
@@ -1013,22 +1337,49 @@ protected:
if (auto Err = Channel.endReceiveMessage())
return Err;
- // Build the handler/responder.
- auto Responder = [this, Handler, Args, &Channel,
- SeqNo]() mutable -> Error {
- using HTraits = detail::HandlerTraits<HandlerT>;
- using FuncReturn = typename Func::ReturnType;
- return detail::respond<FuncReturn>(
- Channel, ResponseId, SeqNo, HTraits::unpackAndRun(Handler, *Args));
- };
-
- // If there is an explicit launch policy then use it to launch the
- // handler.
- if (Launch)
- return Launch(std::move(Responder));
-
- // Otherwise run the handler on the listener thread.
- return Responder();
+ using HTraits = detail::HandlerTraits<HandlerT>;
+ using FuncReturn = typename Func::ReturnType;
+ return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo,
+ HTraits::unpackAndRun(Handler, *Args));
+ };
+ }
+
+ // Wrap the given user handler in the necessary argument-deserialization code,
+ // result-serialization code, and call to the launch policy (if present).
+ template <typename Func, typename HandlerT>
+ WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) {
+ return [this, Handler](ChannelT &Channel,
+ SequenceNumberT SeqNo) mutable -> Error {
+ // Start by deserializing the arguments.
+ using AHTraits = detail::AsyncHandlerTraits<
+ typename detail::HandlerTraits<HandlerT>::Type>;
+ using ArgsTuple =
+ typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type;
+ auto Args = std::make_shared<ArgsTuple>();
+
+ if (auto Err =
+ detail::HandlerTraits<typename Func::Type>::deserializeArgs(
+ Channel, *Args))
+ return Err;
+
+ // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning
+ // for RPCArgs. Void cast RPCArgs to work around this for now.
+ // FIXME: Remove this workaround once we can assume a working GCC version.
+ (void)Args;
+
+ // End receieve message, unlocking the channel for reading.
+ if (auto Err = Channel.endReceiveMessage())
+ return Err;
+
+ using HTraits = detail::HandlerTraits<HandlerT>;
+ using FuncReturn = typename Func::ReturnType;
+ auto Responder =
+ [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error {
+ return detail::respond<FuncReturn>(C, ResponseId, SeqNo,
+ std::move(RetVal));
+ };
+
+ return HTraits::unpackAndRunAsync(Handler, Responder, *Args);
};
}
@@ -1068,66 +1419,31 @@ public:
MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
: BaseClass(C, LazyAutoNegotiation) {}
- /// The LaunchPolicy type allows a launch policy to be specified when adding
- /// a function handler. See addHandler.
- using LaunchPolicy = typename BaseClass::LaunchPolicy;
-
/// Add a handler for the given RPC function.
/// This installs the given handler functor for the given RPC Function, and
/// makes the RPC function available for negotiation/calling from the remote.
- ///
- /// The optional LaunchPolicy argument can be used to control how the handler
- /// is run when called:
- ///
- /// * If no LaunchPolicy is given, the handler code will be run on the RPC
- /// handler thread that is reading from the channel. This handler cannot
- /// make blocking RPC calls (since it would be blocking the thread used to
- /// get the result), but can make non-blocking calls.
- ///
- /// * If a LaunchPolicy is given, the user's handler will be wrapped in a
- /// call to serialize and send the result, and the resulting functor (with
- /// type 'Error()' will be passed to the LaunchPolicy. The user can then
- /// choose to add the wrapped handler to a work queue, spawn a new thread,
- /// or anything else.
template <typename Func, typename HandlerT>
- void addHandler(HandlerT Handler, LaunchPolicy Launch = LaunchPolicy()) {
- return this->template addHandlerImpl<Func>(std::move(Handler),
- std::move(Launch));
+ void addHandler(HandlerT Handler) {
+ return this->template addHandlerImpl<Func>(std::move(Handler));
}
/// Add a class-method as a handler.
template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
- void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...),
- LaunchPolicy Launch = LaunchPolicy()) {
+ void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
addHandler<Func>(
- detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method),
- Launch);
+ detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
}
- /// Negotiate a function id for Func with the other end of the channel.
- template <typename Func> Error negotiateFunction(bool Retry = false) {
- using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate;
-
- // Check if we already have a function id...
- auto I = this->RemoteFunctionIds.find(Func::getPrototype());
- if (I != this->RemoteFunctionIds.end()) {
- // If it's valid there's nothing left to do.
- if (I->second != this->getInvalidFunctionId())
- return Error::success();
- // If it's invalid and we can't re-attempt negotiation, throw an error.
- if (!Retry)
- return orcError(OrcErrorCode::UnknownRPCFunction);
- }
+ template <typename Func, typename HandlerT>
+ void addAsyncHandler(HandlerT Handler) {
+ return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
+ }
- // We don't have a function id for Func yet, call the remote to try to
- // negotiate one.
- if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) {
- this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
- if (*RemoteIdOrErr == this->getInvalidFunctionId())
- return orcError(OrcErrorCode::UnknownRPCFunction);
- return Error::success();
- } else
- return RemoteIdOrErr.takeError();
+ /// Add a class-method as a handler.
+ template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
+ void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
+ addAsyncHandler<Func>(
+ detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
}
/// Return type for non-blocking call primitives.
@@ -1159,7 +1475,6 @@ public:
return Error::success();
},
Args...)) {
- this->abandonPendingResponses();
RTraits::consumeAbandoned(FutureResult.get());
return std::move(Err);
}
@@ -1191,15 +1506,9 @@ public:
typename AltRetT = typename Func::ReturnType>
typename detail::ResultTraits<AltRetT>::ErrorReturnType
callB(const ArgTs &... Args) {
- if (auto FutureResOrErr = callNB<Func>(Args...)) {
- if (auto Err = this->C.send()) {
- this->abandonPendingResponses();
- detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
- std::move(FutureResOrErr->get()));
- return std::move(Err);
- }
+ if (auto FutureResOrErr = callNB<Func>(Args...))
return FutureResOrErr->get();
- } else
+ else
return FutureResOrErr.takeError();
}
@@ -1224,16 +1533,13 @@ private:
SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,
ChannelT, FunctionIdT, SequenceNumberT>;
- using LaunchPolicy = typename BaseClass::LaunchPolicy;
-
public:
SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)
: BaseClass(C, LazyAutoNegotiation) {}
template <typename Func, typename HandlerT>
void addHandler(HandlerT Handler) {
- return this->template addHandlerImpl<Func>(std::move(Handler),
- LaunchPolicy());
+ return this->template addHandlerImpl<Func>(std::move(Handler));
}
template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
@@ -1242,30 +1548,16 @@ public:
detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
}
- /// Negotiate a function id for Func with the other end of the channel.
- template <typename Func> Error negotiateFunction(bool Retry = false) {
- using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate;
-
- // Check if we already have a function id...
- auto I = this->RemoteFunctionIds.find(Func::getPrototype());
- if (I != this->RemoteFunctionIds.end()) {
- // If it's valid there's nothing left to do.
- if (I->second != this->getInvalidFunctionId())
- return Error::success();
- // If it's invalid and we can't re-attempt negotiation, throw an error.
- if (!Retry)
- return orcError(OrcErrorCode::UnknownRPCFunction);
- }
+ template <typename Func, typename HandlerT>
+ void addAsyncHandler(HandlerT Handler) {
+ return this->template addAsyncHandlerImpl<Func>(std::move(Handler));
+ }
- // We don't have a function id for Func yet, call the remote to try to
- // negotiate one.
- if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) {
- this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr;
- if (*RemoteIdOrErr == this->getInvalidFunctionId())
- return orcError(OrcErrorCode::UnknownRPCFunction);
- return Error::success();
- } else
- return RemoteIdOrErr.takeError();
+ /// Add a class-method as a handler.
+ template <typename Func, typename ClassT, typename RetT, typename... ArgTs>
+ void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {
+ addAsyncHandler<Func>(
+ detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));
}
template <typename Func, typename... ArgTs,
@@ -1287,7 +1579,6 @@ public:
return Error::success();
},
Args...)) {
- this->abandonPendingResponses();
detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
std::move(Result));
return std::move(Err);
@@ -1295,7 +1586,6 @@ public:
while (!ReceivedResponse) {
if (auto Err = this->handleOne()) {
- this->abandonPendingResponses();
detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(
std::move(Result));
return std::move(Err);
@@ -1306,24 +1596,40 @@ public:
}
};
+/// Asynchronous dispatch for a function on an RPC endpoint.
+template <typename RPCClass, typename Func>
+class RPCAsyncDispatch {
+public:
+ RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
+
+ template <typename HandlerT, typename... ArgTs>
+ Error operator()(HandlerT Handler, const ArgTs &... Args) const {
+ return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
+ }
+
+private:
+ RPCClass &Endpoint;
+};
+
+/// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
+template <typename Func, typename RPCEndpointT>
+RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
+ return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
+}
+
/// \brief Allows a set of asynchrounous calls to be dispatched, and then
/// waited on as a group.
-template <typename RPCClass> class ParallelCallGroup {
+class ParallelCallGroup {
public:
- /// \brief Construct a parallel call group for the given RPC.
- ParallelCallGroup(RPCClass &RPC) : RPC(RPC), NumOutstandingCalls(0) {}
-
+ ParallelCallGroup() = default;
ParallelCallGroup(const ParallelCallGroup &) = delete;
ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
/// \brief Make as asynchronous call.
- ///
- /// Does not issue a send call to the RPC's channel. The channel may use this
- /// to batch up subsequent calls. A send will automatically be sent when wait
- /// is called.
- template <typename Func, typename HandlerT, typename... ArgTs>
- Error appendCall(HandlerT Handler, const ArgTs &... Args) {
+ template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
+ Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
+ const ArgTs &... Args) {
// Increment the count of outstanding calls. This has to happen before
// we invoke the call, as the handler may (depending on scheduling)
// be run immediately on another thread, and we don't want the decrement
@@ -1346,38 +1652,21 @@ public:
return Err;
};
- return RPC.template appendCallAsync<Func>(std::move(WrappedHandler),
- Args...);
- }
-
- /// \brief Make an asynchronous call.
- ///
- /// The same as appendCall, but also calls send on the channel immediately.
- /// Prefer appendCall if you are about to issue a "wait" call shortly, as
- /// this may allow the channel to better batch the calls.
- template <typename Func, typename HandlerT, typename... ArgTs>
- Error call(HandlerT Handler, const ArgTs &... Args) {
- if (auto Err = appendCall(std::move(Handler), Args...))
- return Err;
- return RPC.sendAppendedCalls();
+ return AsyncDispatch(std::move(WrappedHandler), Args...);
}
/// \brief Blocks until all calls have been completed and their return value
/// handlers run.
- Error wait() {
- if (auto Err = RPC.sendAppendedCalls())
- return Err;
+ void wait() {
std::unique_lock<std::mutex> Lock(M);
while (NumOutstandingCalls > 0)
CV.wait(Lock);
- return Error::success();
}
private:
- RPCClass &RPC;
std::mutex M;
std::condition_variable CV;
- uint32_t NumOutstandingCalls;
+ uint32_t NumOutstandingCalls = 0;
};
/// @brief Convenience class for grouping RPC Functions into APIs that can be
diff --git a/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h b/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h
index 0588d22285981..babcc7f26aab5 100644
--- a/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h
+++ b/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h
@@ -1,4 +1,4 @@
-//===- ObjectLinkingLayer.h - Add object files to a JIT process -*- C++ -*-===//
+//===-- RTDyldObjectLinkingLayer.h - RTDyld-based jit linking --*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
@@ -7,12 +7,12 @@
//
//===----------------------------------------------------------------------===//
//
-// Contains the definition for the object layer of the JIT.
+// Contains the definition for an RTDyld-based, in-process object linking layer.
//
//===----------------------------------------------------------------------===//
-#ifndef LLVM_EXECUTIONENGINE_ORC_OBJECTLINKINGLAYER_H
-#define LLVM_EXECUTIONENGINE_ORC_OBJECTLINKINGLAYER_H
+#ifndef LLVM_EXECUTIONENGINE_ORC_RTDYLDOBJECTLINKINGLAYER_H
+#define LLVM_EXECUTIONENGINE_ORC_RTDYLDOBJECTLINKINGLAYER_H
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
@@ -35,7 +35,7 @@
namespace llvm {
namespace orc {
-class ObjectLinkingLayerBase {
+class RTDyldObjectLinkingLayerBase {
protected:
/// @brief Holds a set of objects to be allocated/linked as a unit in the JIT.
///
@@ -87,7 +87,7 @@ public:
class DoNothingOnNotifyLoaded {
public:
template <typename ObjSetT, typename LoadResult>
- void operator()(ObjectLinkingLayerBase::ObjSetHandleT, const ObjSetT &,
+ void operator()(RTDyldObjectLinkingLayerBase::ObjSetHandleT, const ObjSetT &,
const LoadResult &) {}
};
@@ -98,7 +98,7 @@ public:
/// symbols queried. All objects added to this layer can see each other's
/// symbols.
template <typename NotifyLoadedFtor = DoNothingOnNotifyLoaded>
-class ObjectLinkingLayer : public ObjectLinkingLayerBase {
+class RTDyldObjectLinkingLayer : public RTDyldObjectLinkingLayerBase {
public:
/// @brief Functor for receiving finalization notifications.
typedef std::function<void(ObjSetHandleT)> NotifyFinalizedFtor;
@@ -227,7 +227,7 @@ public:
/// @brief Construct an ObjectLinkingLayer with the given NotifyLoaded,
/// and NotifyFinalized functors.
- ObjectLinkingLayer(
+ RTDyldObjectLinkingLayer(
NotifyLoadedFtor NotifyLoaded = NotifyLoadedFtor(),
NotifyFinalizedFtor NotifyFinalized = NotifyFinalizedFtor())
: NotifyLoaded(std::move(NotifyLoaded)),
@@ -359,4 +359,4 @@ private:
} // end namespace orc
} // end namespace llvm
-#endif // LLVM_EXECUTIONENGINE_ORC_OBJECTLINKINGLAYER_H
+#endif // LLVM_EXECUTIONENGINE_ORC_RTDYLDOBJECTLINKINGLAYER_H
diff --git a/include/llvm/ExecutionEngine/Orc/RawByteChannel.h b/include/llvm/ExecutionEngine/Orc/RawByteChannel.h
index 3b6c84eb19650..52a546f7c6eb9 100644
--- a/include/llvm/ExecutionEngine/Orc/RawByteChannel.h
+++ b/include/llvm/ExecutionEngine/Orc/RawByteChannel.h
@@ -48,7 +48,11 @@ public:
template <typename FunctionIdT, typename SequenceIdT>
Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
writeLock.lock();
- return serializeSeq(*this, FnId, SeqNo);
+ if (auto Err = serializeSeq(*this, FnId, SeqNo)) {
+ writeLock.unlock();
+ return Err;
+ }
+ return Error::success();
}
/// Notify the channel that we're ending a message send.
@@ -63,7 +67,11 @@ public:
template <typename FunctionIdT, typename SequenceNumberT>
Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
readLock.lock();
- return deserializeSeq(*this, FnId, SeqNo);
+ if (auto Err = deserializeSeq(*this, FnId, SeqNo)) {
+ readLock.unlock();
+ return Err;
+ }
+ return Error::success();
}
/// Notify the channel that we're ending a message receive.
@@ -113,11 +121,19 @@ class SerializationTraits<ChannelT, bool, bool,
RawByteChannel, ChannelT>::value>::type> {
public:
static Error serialize(ChannelT &C, bool V) {
- return C.appendBytes(reinterpret_cast<const char *>(&V), 1);
+ uint8_t Tmp = V ? 1 : 0;
+ if (auto Err =
+ C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1))
+ return Err;
+ return Error::success();
}
static Error deserialize(ChannelT &C, bool &V) {
- return C.readBytes(reinterpret_cast<char *>(&V), 1);
+ uint8_t Tmp = 0;
+ if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1))
+ return Err;
+ V = Tmp != 0;
+ return Error::success();
}
};
@@ -134,10 +150,12 @@ public:
}
};
-template <typename ChannelT>
-class SerializationTraits<ChannelT, std::string, const char *,
- typename std::enable_if<std::is_base_of<
- RawByteChannel, ChannelT>::value>::type> {
+template <typename ChannelT, typename T>
+class SerializationTraits<ChannelT, std::string, T,
+ typename std::enable_if<
+ std::is_base_of<RawByteChannel, ChannelT>::value &&
+ (std::is_same<T, const char*>::value ||
+ std::is_same<T, char*>::value)>::type> {
public:
static Error serialize(RawByteChannel &C, const char *S) {
return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C,