diff options
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc')
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h | 12 | ||||
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/OrcError.h | 6 | ||||
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h | 1 | ||||
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h | 23 | ||||
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/RPCSerialization.h | 243 | ||||
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/RPCUtils.h | 787 | ||||
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h (renamed from include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h) | 18 | ||||
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/RawByteChannel.h | 34 |
8 files changed, 833 insertions, 291 deletions
diff --git a/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h b/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h index aa096478cd9e7..7e7f7358938a4 100644 --- a/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h +++ b/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h @@ -376,7 +376,7 @@ private: // Initializers may refer to functions declared (but not defined) in this // module. Build a materializer to clone decls on demand. auto Materializer = createLambdaMaterializer( - [this, &LD, &GVsM](Value *V) -> Value* { + [&LD, &GVsM](Value *V) -> Value* { if (auto *F = dyn_cast<Function>(V)) { // Decls in the original module just get cloned. if (F->isDeclaration()) @@ -419,7 +419,7 @@ private: // Build a resolver for the globals module and add it to the base layer. auto GVsResolver = createLambdaResolver( - [this, &LD, LMId](const std::string &Name) { + [this, &LD](const std::string &Name) { if (auto Sym = LD.StubsMgr->findStub(Name, false)) return Sym; if (auto Sym = LD.findSymbol(BaseLayer, Name, false)) @@ -499,8 +499,8 @@ private: M->setDataLayout(SrcM.getDataLayout()); ValueToValueMapTy VMap; - auto Materializer = createLambdaMaterializer([this, &LD, &LMId, &M, - &VMap](Value *V) -> Value * { + auto Materializer = createLambdaMaterializer([&LD, &LMId, + &M](Value *V) -> Value * { if (auto *GV = dyn_cast<GlobalVariable>(V)) return cloneGlobalVariableDecl(*M, *GV); @@ -546,12 +546,12 @@ private: // Create memory manager and symbol resolver. auto Resolver = createLambdaResolver( - [this, &LD, LMId](const std::string &Name) { + [this, &LD](const std::string &Name) { if (auto Sym = LD.findSymbol(BaseLayer, Name, false)) return Sym; return LD.ExternalSymbolResolver->findSymbolInLogicalDylib(Name); }, - [this, &LD](const std::string &Name) { + [&LD](const std::string &Name) { return LD.ExternalSymbolResolver->findSymbol(Name); }); diff --git a/include/llvm/ExecutionEngine/Orc/OrcError.h b/include/llvm/ExecutionEngine/Orc/OrcError.h index b74988cce2fb0..cbb40fad02230 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcError.h +++ b/include/llvm/ExecutionEngine/Orc/OrcError.h @@ -27,13 +27,15 @@ enum class OrcErrorCode : int { RemoteMProtectAddrUnrecognized, RemoteIndirectStubsOwnerDoesNotExist, RemoteIndirectStubsOwnerIdAlreadyInUse, + RPCConnectionClosed, + RPCCouldNotNegotiateFunction, RPCResponseAbandoned, UnexpectedRPCCall, UnexpectedRPCResponse, - UnknownRPCFunction + UnknownErrorCodeFromRemote }; -Error orcError(OrcErrorCode ErrCode); +std::error_code orcError(OrcErrorCode ErrCode); } // End namespace orc. } // End namespace llvm. diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h index 8647db56cd2f5..02f59d6a831a8 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h @@ -18,6 +18,7 @@ #include "IndirectionUtils.h" #include "OrcRemoteTargetRPCAPI.h" +#include "llvm/ExecutionEngine/RuntimeDyld.h" #include <system_error> #define DEBUG_TYPE "orc-remote" diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h index 506330fe3a5e9..a61ff102be0b0 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h @@ -132,7 +132,7 @@ private: Error setProtections(void *block, unsigned Flags) { auto I = Allocs.find(block); if (I == Allocs.end()) - return orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized); + return errorCodeToError(orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized)); return errorCodeToError( sys::Memory::protectMappedMemory(I->second, Flags)); } @@ -198,7 +198,8 @@ private: Error handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) { auto I = Allocators.find(Id); if (I != Allocators.end()) - return orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse); + return errorCodeToError( + orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse)); DEBUG(dbgs() << " Created allocator " << Id << "\n"); Allocators[Id] = Allocator(); return Error::success(); @@ -207,7 +208,8 @@ private: Error handleCreateIndirectStubsOwner(ResourceIdMgr::ResourceId Id) { auto I = IndirectStubsOwners.find(Id); if (I != IndirectStubsOwners.end()) - return orcError(OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse); + return errorCodeToError( + orcError(OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse)); DEBUG(dbgs() << " Create indirect stubs owner " << Id << "\n"); IndirectStubsOwners[Id] = ISBlockOwnerList(); return Error::success(); @@ -224,7 +226,8 @@ private: Error handleDestroyRemoteAllocator(ResourceIdMgr::ResourceId Id) { auto I = Allocators.find(Id); if (I == Allocators.end()) - return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); + return errorCodeToError( + orcError(OrcErrorCode::RemoteAllocatorDoesNotExist)); Allocators.erase(I); DEBUG(dbgs() << " Destroyed allocator " << Id << "\n"); return Error::success(); @@ -233,7 +236,8 @@ private: Error handleDestroyIndirectStubsOwner(ResourceIdMgr::ResourceId Id) { auto I = IndirectStubsOwners.find(Id); if (I == IndirectStubsOwners.end()) - return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist); + return errorCodeToError( + orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist)); IndirectStubsOwners.erase(I); return Error::success(); } @@ -246,7 +250,8 @@ private: auto StubOwnerItr = IndirectStubsOwners.find(Id); if (StubOwnerItr == IndirectStubsOwners.end()) - return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist); + return errorCodeToError( + orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist)); typename TargetT::IndirectStubsInfo IS; if (auto Err = @@ -361,7 +366,8 @@ private: uint64_t Size, uint32_t Align) { auto I = Allocators.find(Id); if (I == Allocators.end()) - return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); + return errorCodeToError( + orcError(OrcErrorCode::RemoteAllocatorDoesNotExist)); auto &Allocator = I->second; void *LocalAllocAddr = nullptr; if (auto Err = Allocator.allocate(LocalAllocAddr, Size, Align)) @@ -380,7 +386,8 @@ private: JITTargetAddress Addr, uint32_t Flags) { auto I = Allocators.find(Id); if (I == Allocators.end()) - return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); + return errorCodeToError( + orcError(OrcErrorCode::RemoteAllocatorDoesNotExist)); auto &Allocator = I->second; void *LocalAddr = reinterpret_cast<void *>(static_cast<uintptr_t>(Addr)); DEBUG(dbgs() << " Allocator " << Id << " set permissions on " << LocalAddr diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h index 359a9d81b22b5..84a037b2f998b 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h +++ b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -12,6 +12,7 @@ #include "OrcError.h" #include "llvm/Support/thread.h" +#include <map> #include <mutex> #include <sstream> @@ -114,6 +115,35 @@ public: static const char* getName() { return "std::string"; } }; +template <> +class RPCTypeName<Error> { +public: + static const char* getName() { return "Error"; } +}; + +template <typename T> +class RPCTypeName<Expected<T>> { +public: + static const char* getName() { + std::lock_guard<std::mutex> Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) << "Expected<" + << RPCTypeNameSequence<T>() + << ">"; + return Name.data(); + } + +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template <typename T> +std::mutex RPCTypeName<Expected<T>>::NameMutex; + +template <typename T> +std::string RPCTypeName<Expected<T>>::Name; + template <typename T1, typename T2> class RPCTypeName<std::pair<T1, T2>> { public: @@ -243,8 +273,10 @@ class SequenceSerialization<ChannelT, ArgT> { public: template <typename CArgT> - static Error serialize(ChannelT &C, const CArgT &CArg) { - return SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg); + static Error serialize(ChannelT &C, CArgT &&CArg) { + return SerializationTraits<ChannelT, ArgT, + typename std::decay<CArgT>::type>:: + serialize(C, std::forward<CArgT>(CArg)); } template <typename CArgT> @@ -258,19 +290,21 @@ class SequenceSerialization<ChannelT, ArgT, ArgTs...> { public: template <typename CArgT, typename... CArgTs> - static Error serialize(ChannelT &C, const CArgT &CArg, - const CArgTs&... CArgs) { + static Error serialize(ChannelT &C, CArgT &&CArg, + CArgTs &&... CArgs) { if (auto Err = - SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg)) + SerializationTraits<ChannelT, ArgT, typename std::decay<CArgT>::type>:: + serialize(C, std::forward<CArgT>(CArg))) return Err; if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C)) return Err; - return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); + return SequenceSerialization<ChannelT, ArgTs...>:: + serialize(C, std::forward<CArgTs>(CArgs)...); } template <typename CArgT, typename... CArgTs> static Error deserialize(ChannelT &C, CArgT &CArg, - CArgTs&... CArgs) { + CArgTs &... CArgs) { if (auto Err = SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg)) return Err; @@ -281,8 +315,9 @@ public: }; template <typename ChannelT, typename... ArgTs> -Error serializeSeq(ChannelT &C, const ArgTs &... Args) { - return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, Args...); +Error serializeSeq(ChannelT &C, ArgTs &&... Args) { + return SequenceSerialization<ChannelT, typename std::decay<ArgTs>::type...>:: + serialize(C, std::forward<ArgTs>(Args)...); } template <typename ChannelT, typename... ArgTs> @@ -290,6 +325,196 @@ Error deserializeSeq(ChannelT &C, ArgTs &... Args) { return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...); } +template <typename ChannelT> +class SerializationTraits<ChannelT, Error> { +public: + + using WrappedErrorSerializer = + std::function<Error(ChannelT &C, const ErrorInfoBase&)>; + + using WrappedErrorDeserializer = + std::function<Error(ChannelT &C, Error &Err)>; + + template <typename ErrorInfoT, typename SerializeFtor, + typename DeserializeFtor> + static void registerErrorType(std::string Name, SerializeFtor Serialize, + DeserializeFtor Deserialize) { + assert(!Name.empty() && + "The empty string is reserved for the Success value"); + + const std::string *KeyName = nullptr; + { + // We're abusing the stability of std::map here: We take a reference to the + // key of the deserializers map to save us from duplicating the string in + // the serializer. This should be changed to use a stringpool if we switch + // to a map type that may move keys in memory. + std::lock_guard<std::mutex> Lock(DeserializersMutex); + auto I = + Deserializers.insert(Deserializers.begin(), + std::make_pair(std::move(Name), + std::move(Deserialize))); + KeyName = &I->first; + } + + { + assert(KeyName != nullptr && "No keyname pointer"); + std::lock_guard<std::mutex> Lock(SerializersMutex); + // FIXME: Move capture Serialize once we have C++14. + Serializers[ErrorInfoT::classID()] = + [KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error { + assert(EIB.dynamicClassID() == ErrorInfoT::classID() && + "Serializer called for wrong error type"); + if (auto Err = serializeSeq(C, *KeyName)) + return Err; + return Serialize(C, static_cast<const ErrorInfoT&>(EIB)); + }; + } + } + + static Error serialize(ChannelT &C, Error &&Err) { + std::lock_guard<std::mutex> Lock(SerializersMutex); + if (!Err) + return serializeSeq(C, std::string()); + + return handleErrors(std::move(Err), + [&C](const ErrorInfoBase &EIB) { + auto SI = Serializers.find(EIB.dynamicClassID()); + if (SI == Serializers.end()) + return serializeAsStringError(C, EIB); + return (SI->second)(C, EIB); + }); + } + + static Error deserialize(ChannelT &C, Error &Err) { + std::lock_guard<std::mutex> Lock(DeserializersMutex); + + std::string Key; + if (auto Err = deserializeSeq(C, Key)) + return Err; + + if (Key.empty()) { + ErrorAsOutParameter EAO(&Err); + Err = Error::success(); + return Error::success(); + } + + auto DI = Deserializers.find(Key); + assert(DI != Deserializers.end() && "No deserializer for error type"); + return (DI->second)(C, Err); + } + +private: + + static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { + assert(EIB.dynamicClassID() != StringError::classID() && + "StringError serialization not registered"); + std::string ErrMsg; + { + raw_string_ostream ErrMsgStream(ErrMsg); + EIB.log(ErrMsgStream); + } + return serialize(C, make_error<StringError>(std::move(ErrMsg), + inconvertibleErrorCode())); + } + + static std::mutex SerializersMutex; + static std::mutex DeserializersMutex; + static std::map<const void*, WrappedErrorSerializer> Serializers; + static std::map<std::string, WrappedErrorDeserializer> Deserializers; +}; + +template <typename ChannelT> +std::mutex SerializationTraits<ChannelT, Error>::SerializersMutex; + +template <typename ChannelT> +std::mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; + +template <typename ChannelT> +std::map<const void*, + typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer> +SerializationTraits<ChannelT, Error>::Serializers; + +template <typename ChannelT> +std::map<std::string, + typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer> +SerializationTraits<ChannelT, Error>::Deserializers; + +template <typename ChannelT> +void registerStringError() { + static bool AlreadyRegistered = false; + if (!AlreadyRegistered) { + SerializationTraits<ChannelT, Error>:: + template registerErrorType<StringError>( + "StringError", + [](ChannelT &C, const StringError &SE) { + return serializeSeq(C, SE.getMessage()); + }, + [](ChannelT &C, Error &Err) { + ErrorAsOutParameter EAO(&Err); + std::string Msg; + if (auto E2 = deserializeSeq(C, Msg)) + return E2; + Err = + make_error<StringError>(std::move(Msg), + orcError( + OrcErrorCode::UnknownErrorCodeFromRemote)); + return Error::success(); + }); + AlreadyRegistered = true; + } +} + +/// SerializationTraits for Expected<T1> from an Expected<T2>. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> { +public: + + static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) { + if (ValOrErr) { + if (auto Err = serializeSeq(C, true)) + return Err; + return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr); + } + if (auto Err = serializeSeq(C, false)) + return Err; + return serializeSeq(C, ValOrErr.takeError()); + } + + static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) { + ExpectedAsOutParameter<T2> EAO(&ValOrErr); + bool HasValue; + if (auto Err = deserializeSeq(C, HasValue)) + return Err; + if (HasValue) + return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr); + Error Err = Error::success(); + if (auto E2 = deserializeSeq(C, Err)) + return E2; + ValOrErr = std::move(Err); + return Error::success(); + } +}; + +/// SerializationTraits for Expected<T1> from a T2. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, T2> { +public: + + static Error serialize(ChannelT &C, T2 &&Val) { + return serializeSeq(C, Expected<T2>(std::forward<T2>(Val))); + } +}; + +/// SerializationTraits for Expected<T1> from an Error. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, Expected<T>, Error> { +public: + + static Error serialize(ChannelT &C, Error &&Err) { + return serializeSeq(C, Expected<T>(std::move(Err))); + } +}; + /// SerializationTraits default specialization for std::pair. template <typename ChannelT, typename T1, typename T2> class SerializationTraits<ChannelT, std::pair<T1, T2>> { diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 37e2e66e5af48..6212f64ff3195 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -26,27 +26,115 @@ #include "llvm/ExecutionEngine/Orc/OrcError.h" #include "llvm/ExecutionEngine/Orc/RPCSerialization.h" -#ifdef _MSC_VER -// concrt.h depends on eh.h for __uncaught_exception declaration -// even if we disable exceptions. -#include <eh.h> - -// Disable warnings from ppltasks.h transitively included by <future>. -#pragma warning(push) -#pragma warning(disable : 4530) -#pragma warning(disable : 4062) -#endif - #include <future> -#ifdef _MSC_VER -#pragma warning(pop) -#endif - namespace llvm { namespace orc { namespace rpc { +/// Base class of all fatal RPC errors (those that necessarily result in the +/// termination of the RPC session). +class RPCFatalError : public ErrorInfo<RPCFatalError> { +public: + static char ID; +}; + +/// RPCConnectionClosed is returned from RPC operations if the RPC connection +/// has already been closed due to either an error or graceful disconnection. +class ConnectionClosed : public ErrorInfo<ConnectionClosed> { +public: + static char ID; + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; +}; + +/// BadFunctionCall is returned from handleOne when the remote makes a call with +/// an unrecognized function id. +/// +/// This error is fatal because Orc RPC needs to know how to parse a function +/// call to know where the next call starts, and if it doesn't recognize the +/// function id it cannot parse the call. +template <typename FnIdT, typename SeqNoT> +class BadFunctionCall + : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> { +public: + static char ID; + + BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) + : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnexpectedRPCCall); + } + + void log(raw_ostream &OS) const override { + OS << "Call to invalid RPC function id '" << FnId << "' with " + "sequence number " << SeqNo; + } + +private: + FnIdT FnId; + SeqNoT SeqNo; +}; + +template <typename FnIdT, typename SeqNoT> +char BadFunctionCall<FnIdT, SeqNoT>::ID = 0; + +/// InvalidSequenceNumberForResponse is returned from handleOne when a response +/// call arrives with a sequence number that doesn't correspond to any in-flight +/// function call. +/// +/// This error is fatal because Orc RPC needs to know how to parse the rest of +/// the response call to know where the next call starts, and if it doesn't have +/// a result parser for this sequence number it can't do that. +template <typename SeqNoT> +class InvalidSequenceNumberForResponse + : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> { +public: + static char ID; + + InvalidSequenceNumberForResponse(SeqNoT SeqNo) + : SeqNo(std::move(SeqNo)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnexpectedRPCCall); + }; + + void log(raw_ostream &OS) const override { + OS << "Response has unknown sequence number " << SeqNo; + } +private: + SeqNoT SeqNo; +}; + +template <typename SeqNoT> +char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0; + +/// This non-fatal error will be passed to asynchronous result handlers in place +/// of a result if the connection goes down before a result returns, or if the +/// function to be called cannot be negotiated with the remote. +class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> { +public: + static char ID; + + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; +}; + +/// This error is returned if the remote does not have a handler installed for +/// the given RPC function. +class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> { +public: + static char ID; + + CouldNotNegotiate(std::string Signature); + std::error_code convertToErrorCode() const override; + void log(raw_ostream &OS) const override; + const std::string &getSignature() const { return Signature; } +private: + std::string Signature; +}; + template <typename DerivedFunc, typename FnT> class Function; // RPC Function class. @@ -82,16 +170,6 @@ std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex; template <typename DerivedFunc, typename RetT, typename... ArgTs> std::string Function<DerivedFunc, RetT(ArgTs...)>::Name; -/// Provides a typedef for a tuple containing the decayed argument types. -template <typename T> class FunctionArgsTuple; - -template <typename RetT, typename... ArgTs> -class FunctionArgsTuple<RetT(ArgTs...)> { -public: - using Type = std::tuple<typename std::decay< - typename std::remove_reference<ArgTs>::type>::type...>; -}; - /// Allocates RPC function ids during autonegotiation. /// Specializations of this class must provide four members: /// @@ -196,6 +274,16 @@ public: #endif // _MSC_VER +/// Provides a typedef for a tuple containing the decayed argument types. +template <typename T> class FunctionArgsTuple; + +template <typename RetT, typename... ArgTs> +class FunctionArgsTuple<RetT(ArgTs...)> { +public: + using Type = std::tuple<typename std::decay< + typename std::remove_reference<ArgTs>::type>::type...>; +}; + // ResultTraits provides typedefs and utilities specific to the return type // of functions. template <typename RetT> class ResultTraits { @@ -274,43 +362,132 @@ template <> class ResultTraits<Error> : public ResultTraits<void> {}; template <typename RetT> class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; +// Determines whether an RPC function's defined error return type supports +// error return value. +template <typename T> +class SupportsErrorReturn { +public: + static const bool value = false; +}; + +template <> +class SupportsErrorReturn<Error> { +public: + static const bool value = true; +}; + +template <typename T> +class SupportsErrorReturn<Expected<T>> { +public: + static const bool value = true; +}; + +// RespondHelper packages return values based on whether or not the declared +// RPC function return type supports error returns. +template <bool FuncSupportsErrorReturn> +class RespondHelper; + +// RespondHelper specialization for functions that support error returns. +template <> +class RespondHelper<true> { +public: + + // Send Expected<T>. + template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected<HandlerRetT> ResultOrErr) { + if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>()) + return ResultOrErr.takeError(); + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits<ChannelT, WireRetT, + Expected<HandlerRetT>>::serialize( + C, std::move(ResultOrErr))) + return Err; + + // Close the response message. + return C.endSendMessage(); + } + + template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err && Err.isA<RPCFatalError>()) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + if (auto Err2 = serializeSeq(C, std::move(Err))) + return Err2; + return C.endSendMessage(); + } + +}; + +// RespondHelper specialization for functions that do not support error returns. +template <> +class RespondHelper<false> { +public: + + template <typename WireRetT, typename HandlerRetT, typename ChannelT, + typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected<HandlerRetT> ResultOrErr) { + if (auto Err = ResultOrErr.takeError()) + return Err; + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( + C, *ResultOrErr)) + return Err; + + // Close the response message. + return C.endSendMessage(); + } + + template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + return C.endSendMessage(); + } + +}; + + // Send a response of the given wire return type (WireRetT) over the // channel, with the given sequence number. template <typename WireRetT, typename HandlerRetT, typename ChannelT, typename FunctionIdT, typename SequenceNumberT> -static Error respond(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) { - // If this was an error bail out. - // FIXME: Send an "error" message to the client if this is not a channel - // failure? - if (auto Err = ResultOrErr.takeError()) - return Err; - - // Open the response message. - if (auto Err = C.startSendMessage(ResponseId, SeqNo)) - return Err; - - // Serialize the result. - if (auto Err = - SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( - C, *ResultOrErr)) - return Err; - - // Close the response message. - return C.endSendMessage(); +Error respond(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) { + return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: + template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr)); } // Send an empty response message on the given channel to indicate that // the handler ran. template <typename WireRetT, typename ChannelT, typename FunctionIdT, typename SequenceNumberT> -static Error respond(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, Error Err) { - if (Err) - return Err; - if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) - return Err2; - return C.endSendMessage(); +Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, + Error Err) { + return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: + sendResult(C, ResponseId, SeqNo, std::move(Err)); } // Converts a given type to the equivalent error return type. @@ -339,6 +516,29 @@ public: using Type = Error; }; +// Traits class that strips the response function from the list of handler +// arguments. +template <typename FnT> class AsyncHandlerTraits; + +template <typename ResultT, typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Expected<ResultT>; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> { +public: + using Type = Error(ArgTs...); + using ResultType = Error; +}; + +template <typename ResponseHandlerT, typename... ArgTs> +class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> : + public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type, + ArgTs...)> {}; + // This template class provides utilities related to RPC function handlers. // The base case applies to non-function types (the template class is // specialized for function types) and inherits from the appropriate @@ -358,15 +558,20 @@ public: // Return type of the handler. using ReturnType = RetT; - // A std::tuple wrapping the handler arguments. - using ArgStorage = typename FunctionArgsTuple<RetT(ArgTs...)>::Type; - // Call the given handler with the given arguments. - template <typename HandlerT> + template <typename HandlerT, typename... TArgTs> static typename WrappedHandlerReturn<RetT>::Type - unpackAndRun(HandlerT &Handler, ArgStorage &Args) { + unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) { return unpackAndRunHelper(Handler, Args, - llvm::index_sequence_for<ArgTs...>()); + llvm::index_sequence_for<TArgTs...>()); + } + + // Call the given handler with the given arguments. + template <typename HandlerT, typename ResponderT, typename... TArgTs> + static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, + std::tuple<TArgTs...> &Args) { + return unpackAndRunAsyncHelper(Handler, Responder, Args, + llvm::index_sequence_for<TArgTs...>()); } // Call the given handler with the given arguments. @@ -379,11 +584,11 @@ public: return Error::success(); } - template <typename HandlerT> + template <typename HandlerT, typename... TArgTs> static typename std::enable_if< !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value, typename HandlerTraits<HandlerT>::ReturnType>::type - run(HandlerT &Handler, ArgTs... Args) { + run(HandlerT &Handler, TArgTs... Args) { return Handler(std::move(Args)...); } @@ -408,15 +613,31 @@ private: C, std::get<Indexes>(Args)...); } - template <typename HandlerT, size_t... Indexes> + template <typename HandlerT, typename ArgTuple, size_t... Indexes> static typename WrappedHandlerReturn< typename HandlerTraits<HandlerT>::ReturnType>::Type - unpackAndRunHelper(HandlerT &Handler, ArgStorage &Args, + unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args, llvm::index_sequence<Indexes...>) { return run(Handler, std::move(std::get<Indexes>(Args))...); } + + + template <typename HandlerT, typename ResponderT, typename ArgTuple, + size_t... Indexes> + static typename WrappedHandlerReturn< + typename HandlerTraits<HandlerT>::ReturnType>::Type + unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, + ArgTuple &Args, + llvm::index_sequence<Indexes...>) { + return run(Handler, Responder, std::move(std::get<Indexes>(Args))...); + } }; +// Handler traits for free functions. +template <typename RetT, typename... ArgTs> +class HandlerTraits<RetT(*)(ArgTs...)> + : public HandlerTraits<RetT(ArgTs...)> {}; + // Handler traits for class methods (especially call operators for lambdas). template <typename Class, typename RetT, typename... ArgTs> class HandlerTraits<RetT (Class::*)(ArgTs...)> @@ -471,7 +692,7 @@ public: // Create an error instance representing an abandoned response. static Error createAbandonedResponseError() { - return orcError(OrcErrorCode::RPCResponseAbandoned); + return make_error<ResponseAbandoned>(); } }; @@ -493,7 +714,7 @@ public: return Err; if (auto Err = C.endReceiveMessage()) return Err; - return Handler(Result); + return Handler(std::move(Result)); } // Abandon this response by calling the handler with an 'abandoned response' @@ -538,6 +759,72 @@ private: HandlerT Handler; }; +template <typename ChannelT, typename FuncRetT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using HandlerArgType = typename ResponseHandlerArg< + typename HandlerTraits<HandlerT>::Type>::ArgType; + HandlerArgType Result((typename HandlerArgType::value_type())); + + if (auto Err = + SerializationTraits<ChannelT, Expected<FuncRetT>, + HandlerArgType>::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +template <typename ChannelT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, Error, HandlerT> + : public ResponseHandler<ChannelT> { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + Error Result = Error::success(); + if (auto Err = + SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + // Create a ResponseHandler from a given user handler. template <typename ChannelT, typename FuncRetT, typename HandlerT> std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) { @@ -758,8 +1045,13 @@ public: auto NegotiateId = FnIdAllocator.getNegotiateId(); RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( - [this](const std::string &Name) { return handleNegotiate(Name); }, - LaunchPolicy()); + [this](const std::string &Name) { return handleNegotiate(Name); }); + } + + + /// Negotiate a function id for Func with the other end of the channel. + template <typename Func> Error negotiateFunction(bool Retry = false) { + return getRemoteFunctionId<Func>(true, Retry).takeError(); } /// Append a call Func, does not call send on the channel. @@ -777,14 +1069,12 @@ public: // Look up the function ID. FunctionIdT FnId; - if (auto FnIdOrErr = getRemoteFunctionId<Func>()) + if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false)) FnId = *FnIdOrErr; else { - // This isn't a channel error so we don't want to abandon other pending - // responses, but we still need to run the user handler with an error to - // let them know the call failed. - if (auto Err = Handler(orcError(OrcErrorCode::UnknownRPCFunction))) - report_fatal_error(std::move(Err)); + // Negotiation failed. Notify the handler then return the negotiate-failed + // error. + cantFail(Handler(make_error<ResponseAbandoned>())); return FnIdOrErr.takeError(); } @@ -807,20 +1097,20 @@ public: // Open the function call message. if (auto Err = C.startSendMessage(FnId, SeqNo)) { abandonPendingResponses(); - return joinErrors(std::move(Err), C.endSendMessage()); + return Err; } // Serialize the call arguments. if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs( C, Args...)) { abandonPendingResponses(); - return joinErrors(std::move(Err), C.endSendMessage()); + return Err; } // Close the function call messagee. if (auto Err = C.endSendMessage()) { abandonPendingResponses(); - return std::move(Err); + return Err; } return Error::success(); @@ -839,8 +1129,10 @@ public: Error handleOne() { FunctionIdT FnId; SequenceNumberT SeqNo; - if (auto Err = C.startReceiveMessage(FnId, SeqNo)) + if (auto Err = C.startReceiveMessage(FnId, SeqNo)) { + abandonPendingResponses(); return Err; + } if (FnId == ResponseId) return handleResponse(SeqNo); auto I = Handlers.find(FnId); @@ -848,7 +1140,8 @@ public: return I->second(C, SeqNo); // else: No handler found. Report error to client? - return orcError(OrcErrorCode::UnexpectedRPCCall); + return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId, + SeqNo); } /// Helper for handling setter procedures - this method returns a functor that @@ -887,10 +1180,25 @@ public: SequenceNumberMgr.reset(); } + /// Remove the handler for the given function. + /// A handler must currently be registered for this function. + template <typename Func> + void removeHandler() { + auto IdItr = LocalFunctionIds.find(Func::getPrototype()); + assert(IdItr != LocalFunctionIds.end() && + "Function does not have a registered handler"); + auto HandlerItr = Handlers.find(IdItr->second); + assert(HandlerItr != Handlers.end() && + "Function does not have a registered handler"); + Handlers.erase(HandlerItr); + } + + /// Clear all handlers. + void clearHandlers() { + Handlers.clear(); + } + protected: - // The LaunchPolicy type allows a launch policy to be specified when adding - // a function handler. See addHandlerImpl. - using LaunchPolicy = std::function<Error(std::function<Error()>)>; FunctionIdT getInvalidFunctionId() const { return FnIdAllocator.getInvalidId(); @@ -899,7 +1207,7 @@ protected: /// Add the given handler to the handler map and make it available for /// autonegotiation and execution. template <typename Func, typename HandlerT> - void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) { + void addHandlerImpl(HandlerT Handler) { static_assert(detail::RPCArgTypeCheck< CanDeserializeCheck, typename Func::Type, @@ -908,8 +1216,22 @@ protected: FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); LocalFunctionIds[Func::getPrototype()] = NewFnId; - Handlers[NewFnId] = - wrapHandler<Func>(std::move(Handler), std::move(Launch)); + Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); + } + + template <typename Func, typename HandlerT> + void addAsyncHandlerImpl(HandlerT Handler) { + + static_assert(detail::RPCArgTypeCheck< + CanDeserializeCheck, typename Func::Type, + typename detail::AsyncHandlerTraits< + typename detail::HandlerTraits<HandlerT>::Type + >::Type>::value, + ""); + + FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); + LocalFunctionIds[Func::getPrototype()] = NewFnId; + Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler)); } Error handleResponse(SequenceNumberT SeqNo) { @@ -929,7 +1251,8 @@ protected: // Unlock the pending results map to prevent recursive lock. Lock.unlock(); abandonPendingResponses(); - return orcError(OrcErrorCode::UnexpectedRPCResponse); + return make_error< + InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo); } } @@ -951,41 +1274,39 @@ protected: return I->second; } - // Find the remote FunctionId for the given function, which must be in the - // RemoteFunctionIds map. - template <typename Func> Expected<FunctionIdT> getRemoteFunctionId() { - // Try to find the id for the given function. - auto I = RemoteFunctionIds.find(Func::getPrototype()); + // Find the remote FunctionId for the given function. + template <typename Func> + Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap, + bool NegotiateIfInvalid) { + bool DoNegotiate; - // If we have it in the map, return it. - if (I != RemoteFunctionIds.end()) - return I->second; + // Check if we already have a function id... + auto I = RemoteFunctionIds.find(Func::getPrototype()); + if (I != RemoteFunctionIds.end()) { + // If it's valid there's nothing left to do. + if (I->second != getInvalidFunctionId()) + return I->second; + DoNegotiate = NegotiateIfInvalid; + } else + DoNegotiate = NegotiateIfNotInMap; - // Otherwise, if we have auto-negotiation enabled, try to negotiate it. - if (LazyAutoNegotiation) { + // We don't have a function id for Func yet, but we're allowed to try to + // negotiate one. + if (DoNegotiate) { auto &Impl = static_cast<ImplT &>(*this); if (auto RemoteIdOrErr = - Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { - auto &RemoteId = *RemoteIdOrErr; - - // If autonegotiation indicates that the remote end doesn't support this - // function, return an unknown function error. - if (RemoteId == getInvalidFunctionId()) - return orcError(OrcErrorCode::UnknownRPCFunction); - - // Autonegotiation succeeded and returned a valid id. Update the map and - // return the id. - RemoteFunctionIds[Func::getPrototype()] = RemoteId; - return RemoteId; - } else { - // Autonegotiation failed. Return the error. + Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { + RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + if (*RemoteIdOrErr == getInvalidFunctionId()) + return make_error<CouldNotNegotiate>(Func::getPrototype()); + return *RemoteIdOrErr; + } else return RemoteIdOrErr.takeError(); - } } - // No key was available in the map and autonegotiation wasn't enabled. - // Return an unknown function error. - return orcError(OrcErrorCode::UnknownRPCFunction); + // No key was available in the map and we weren't allowed to try to + // negotiate one, so return an unknown function error. + return make_error<CouldNotNegotiate>(Func::getPrototype()); } using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>; @@ -993,12 +1314,15 @@ protected: // Wrap the given user handler in the necessary argument-deserialization code, // result-serialization code, and call to the launch policy (if present). template <typename Func, typename HandlerT> - WrappedHandlerFn wrapHandler(HandlerT Handler, LaunchPolicy Launch) { - return [this, Handler, Launch](ChannelT &Channel, - SequenceNumberT SeqNo) mutable -> Error { + WrappedHandlerFn wrapHandler(HandlerT Handler) { + return [this, Handler](ChannelT &Channel, + SequenceNumberT SeqNo) mutable -> Error { // Start by deserializing the arguments. - auto Args = std::make_shared< - typename detail::HandlerTraits<HandlerT>::ArgStorage>(); + using ArgsTuple = + typename detail::FunctionArgsTuple< + typename detail::HandlerTraits<HandlerT>::Type>::Type; + auto Args = std::make_shared<ArgsTuple>(); + if (auto Err = detail::HandlerTraits<typename Func::Type>::deserializeArgs( Channel, *Args)) @@ -1013,22 +1337,49 @@ protected: if (auto Err = Channel.endReceiveMessage()) return Err; - // Build the handler/responder. - auto Responder = [this, Handler, Args, &Channel, - SeqNo]() mutable -> Error { - using HTraits = detail::HandlerTraits<HandlerT>; - using FuncReturn = typename Func::ReturnType; - return detail::respond<FuncReturn>( - Channel, ResponseId, SeqNo, HTraits::unpackAndRun(Handler, *Args)); - }; - - // If there is an explicit launch policy then use it to launch the - // handler. - if (Launch) - return Launch(std::move(Responder)); - - // Otherwise run the handler on the listener thread. - return Responder(); + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, + HTraits::unpackAndRun(Handler, *Args)); + }; + } + + // Wrap the given user handler in the necessary argument-deserialization code, + // result-serialization code, and call to the launch policy (if present). + template <typename Func, typename HandlerT> + WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { + return [this, Handler](ChannelT &Channel, + SequenceNumberT SeqNo) mutable -> Error { + // Start by deserializing the arguments. + using AHTraits = detail::AsyncHandlerTraits< + typename detail::HandlerTraits<HandlerT>::Type>; + using ArgsTuple = + typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type; + auto Args = std::make_shared<ArgsTuple>(); + + if (auto Err = + detail::HandlerTraits<typename Func::Type>::deserializeArgs( + Channel, *Args)) + return Err; + + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)Args; + + // End receieve message, unlocking the channel for reading. + if (auto Err = Channel.endReceiveMessage()) + return Err; + + using HTraits = detail::HandlerTraits<HandlerT>; + using FuncReturn = typename Func::ReturnType; + auto Responder = + [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error { + return detail::respond<FuncReturn>(C, ResponseId, SeqNo, + std::move(RetVal)); + }; + + return HTraits::unpackAndRunAsync(Handler, Responder, *Args); }; } @@ -1068,66 +1419,31 @@ public: MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) : BaseClass(C, LazyAutoNegotiation) {} - /// The LaunchPolicy type allows a launch policy to be specified when adding - /// a function handler. See addHandler. - using LaunchPolicy = typename BaseClass::LaunchPolicy; - /// Add a handler for the given RPC function. /// This installs the given handler functor for the given RPC Function, and /// makes the RPC function available for negotiation/calling from the remote. - /// - /// The optional LaunchPolicy argument can be used to control how the handler - /// is run when called: - /// - /// * If no LaunchPolicy is given, the handler code will be run on the RPC - /// handler thread that is reading from the channel. This handler cannot - /// make blocking RPC calls (since it would be blocking the thread used to - /// get the result), but can make non-blocking calls. - /// - /// * If a LaunchPolicy is given, the user's handler will be wrapped in a - /// call to serialize and send the result, and the resulting functor (with - /// type 'Error()' will be passed to the LaunchPolicy. The user can then - /// choose to add the wrapped handler to a work queue, spawn a new thread, - /// or anything else. template <typename Func, typename HandlerT> - void addHandler(HandlerT Handler, LaunchPolicy Launch = LaunchPolicy()) { - return this->template addHandlerImpl<Func>(std::move(Handler), - std::move(Launch)); + void addHandler(HandlerT Handler) { + return this->template addHandlerImpl<Func>(std::move(Handler)); } /// Add a class-method as a handler. template <typename Func, typename ClassT, typename RetT, typename... ArgTs> - void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...), - LaunchPolicy Launch = LaunchPolicy()) { + void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { addHandler<Func>( - detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method), - Launch); + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); } - /// Negotiate a function id for Func with the other end of the channel. - template <typename Func> Error negotiateFunction(bool Retry = false) { - using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; - - // Check if we already have a function id... - auto I = this->RemoteFunctionIds.find(Func::getPrototype()); - if (I != this->RemoteFunctionIds.end()) { - // If it's valid there's nothing left to do. - if (I->second != this->getInvalidFunctionId()) - return Error::success(); - // If it's invalid and we can't re-attempt negotiation, throw an error. - if (!Retry) - return orcError(OrcErrorCode::UnknownRPCFunction); - } + template <typename Func, typename HandlerT> + void addAsyncHandler(HandlerT Handler) { + return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); + } - // We don't have a function id for Func yet, call the remote to try to - // negotiate one. - if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) { - this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; - if (*RemoteIdOrErr == this->getInvalidFunctionId()) - return orcError(OrcErrorCode::UnknownRPCFunction); - return Error::success(); - } else - return RemoteIdOrErr.takeError(); + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addAsyncHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); } /// Return type for non-blocking call primitives. @@ -1159,7 +1475,6 @@ public: return Error::success(); }, Args...)) { - this->abandonPendingResponses(); RTraits::consumeAbandoned(FutureResult.get()); return std::move(Err); } @@ -1191,15 +1506,9 @@ public: typename AltRetT = typename Func::ReturnType> typename detail::ResultTraits<AltRetT>::ErrorReturnType callB(const ArgTs &... Args) { - if (auto FutureResOrErr = callNB<Func>(Args...)) { - if (auto Err = this->C.send()) { - this->abandonPendingResponses(); - detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( - std::move(FutureResOrErr->get())); - return std::move(Err); - } + if (auto FutureResOrErr = callNB<Func>(Args...)) return FutureResOrErr->get(); - } else + else return FutureResOrErr.takeError(); } @@ -1224,16 +1533,13 @@ private: SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>, ChannelT, FunctionIdT, SequenceNumberT>; - using LaunchPolicy = typename BaseClass::LaunchPolicy; - public: SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation) : BaseClass(C, LazyAutoNegotiation) {} template <typename Func, typename HandlerT> void addHandler(HandlerT Handler) { - return this->template addHandlerImpl<Func>(std::move(Handler), - LaunchPolicy()); + return this->template addHandlerImpl<Func>(std::move(Handler)); } template <typename Func, typename ClassT, typename RetT, typename... ArgTs> @@ -1242,30 +1548,16 @@ public: detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); } - /// Negotiate a function id for Func with the other end of the channel. - template <typename Func> Error negotiateFunction(bool Retry = false) { - using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; - - // Check if we already have a function id... - auto I = this->RemoteFunctionIds.find(Func::getPrototype()); - if (I != this->RemoteFunctionIds.end()) { - // If it's valid there's nothing left to do. - if (I->second != this->getInvalidFunctionId()) - return Error::success(); - // If it's invalid and we can't re-attempt negotiation, throw an error. - if (!Retry) - return orcError(OrcErrorCode::UnknownRPCFunction); - } + template <typename Func, typename HandlerT> + void addAsyncHandler(HandlerT Handler) { + return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); + } - // We don't have a function id for Func yet, call the remote to try to - // negotiate one. - if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) { - this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; - if (*RemoteIdOrErr == this->getInvalidFunctionId()) - return orcError(OrcErrorCode::UnknownRPCFunction); - return Error::success(); - } else - return RemoteIdOrErr.takeError(); + /// Add a class-method as a handler. + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addAsyncHandler<Func>( + detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method)); } template <typename Func, typename... ArgTs, @@ -1287,7 +1579,6 @@ public: return Error::success(); }, Args...)) { - this->abandonPendingResponses(); detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( std::move(Result)); return std::move(Err); @@ -1295,7 +1586,6 @@ public: while (!ReceivedResponse) { if (auto Err = this->handleOne()) { - this->abandonPendingResponses(); detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( std::move(Result)); return std::move(Err); @@ -1306,24 +1596,40 @@ public: } }; +/// Asynchronous dispatch for a function on an RPC endpoint. +template <typename RPCClass, typename Func> +class RPCAsyncDispatch { +public: + RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {} + + template <typename HandlerT, typename... ArgTs> + Error operator()(HandlerT Handler, const ArgTs &... Args) const { + return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...); + } + +private: + RPCClass &Endpoint; +}; + +/// Construct an asynchronous dispatcher from an RPC endpoint and a Func. +template <typename Func, typename RPCEndpointT> +RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) { + return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint); +} + /// \brief Allows a set of asynchrounous calls to be dispatched, and then /// waited on as a group. -template <typename RPCClass> class ParallelCallGroup { +class ParallelCallGroup { public: - /// \brief Construct a parallel call group for the given RPC. - ParallelCallGroup(RPCClass &RPC) : RPC(RPC), NumOutstandingCalls(0) {} - + ParallelCallGroup() = default; ParallelCallGroup(const ParallelCallGroup &) = delete; ParallelCallGroup &operator=(const ParallelCallGroup &) = delete; /// \brief Make as asynchronous call. - /// - /// Does not issue a send call to the RPC's channel. The channel may use this - /// to batch up subsequent calls. A send will automatically be sent when wait - /// is called. - template <typename Func, typename HandlerT, typename... ArgTs> - Error appendCall(HandlerT Handler, const ArgTs &... Args) { + template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs> + Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler, + const ArgTs &... Args) { // Increment the count of outstanding calls. This has to happen before // we invoke the call, as the handler may (depending on scheduling) // be run immediately on another thread, and we don't want the decrement @@ -1346,38 +1652,21 @@ public: return Err; }; - return RPC.template appendCallAsync<Func>(std::move(WrappedHandler), - Args...); - } - - /// \brief Make an asynchronous call. - /// - /// The same as appendCall, but also calls send on the channel immediately. - /// Prefer appendCall if you are about to issue a "wait" call shortly, as - /// this may allow the channel to better batch the calls. - template <typename Func, typename HandlerT, typename... ArgTs> - Error call(HandlerT Handler, const ArgTs &... Args) { - if (auto Err = appendCall(std::move(Handler), Args...)) - return Err; - return RPC.sendAppendedCalls(); + return AsyncDispatch(std::move(WrappedHandler), Args...); } /// \brief Blocks until all calls have been completed and their return value /// handlers run. - Error wait() { - if (auto Err = RPC.sendAppendedCalls()) - return Err; + void wait() { std::unique_lock<std::mutex> Lock(M); while (NumOutstandingCalls > 0) CV.wait(Lock); - return Error::success(); } private: - RPCClass &RPC; std::mutex M; std::condition_variable CV; - uint32_t NumOutstandingCalls; + uint32_t NumOutstandingCalls = 0; }; /// @brief Convenience class for grouping RPC Functions into APIs that can be diff --git a/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h b/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h index 0588d22285981..babcc7f26aab5 100644 --- a/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h +++ b/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h @@ -1,4 +1,4 @@ -//===- ObjectLinkingLayer.h - Add object files to a JIT process -*- C++ -*-===// +//===-- RTDyldObjectLinkingLayer.h - RTDyld-based jit linking --*- C++ -*-===// // // The LLVM Compiler Infrastructure // @@ -7,12 +7,12 @@ // //===----------------------------------------------------------------------===// // -// Contains the definition for the object layer of the JIT. +// Contains the definition for an RTDyld-based, in-process object linking layer. // //===----------------------------------------------------------------------===// -#ifndef LLVM_EXECUTIONENGINE_ORC_OBJECTLINKINGLAYER_H -#define LLVM_EXECUTIONENGINE_ORC_OBJECTLINKINGLAYER_H +#ifndef LLVM_EXECUTIONENGINE_ORC_RTDYLDOBJECTLINKINGLAYER_H +#define LLVM_EXECUTIONENGINE_ORC_RTDYLDOBJECTLINKINGLAYER_H #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringMap.h" @@ -35,7 +35,7 @@ namespace llvm { namespace orc { -class ObjectLinkingLayerBase { +class RTDyldObjectLinkingLayerBase { protected: /// @brief Holds a set of objects to be allocated/linked as a unit in the JIT. /// @@ -87,7 +87,7 @@ public: class DoNothingOnNotifyLoaded { public: template <typename ObjSetT, typename LoadResult> - void operator()(ObjectLinkingLayerBase::ObjSetHandleT, const ObjSetT &, + void operator()(RTDyldObjectLinkingLayerBase::ObjSetHandleT, const ObjSetT &, const LoadResult &) {} }; @@ -98,7 +98,7 @@ public: /// symbols queried. All objects added to this layer can see each other's /// symbols. template <typename NotifyLoadedFtor = DoNothingOnNotifyLoaded> -class ObjectLinkingLayer : public ObjectLinkingLayerBase { +class RTDyldObjectLinkingLayer : public RTDyldObjectLinkingLayerBase { public: /// @brief Functor for receiving finalization notifications. typedef std::function<void(ObjSetHandleT)> NotifyFinalizedFtor; @@ -227,7 +227,7 @@ public: /// @brief Construct an ObjectLinkingLayer with the given NotifyLoaded, /// and NotifyFinalized functors. - ObjectLinkingLayer( + RTDyldObjectLinkingLayer( NotifyLoadedFtor NotifyLoaded = NotifyLoadedFtor(), NotifyFinalizedFtor NotifyFinalized = NotifyFinalizedFtor()) : NotifyLoaded(std::move(NotifyLoaded)), @@ -359,4 +359,4 @@ private: } // end namespace orc } // end namespace llvm -#endif // LLVM_EXECUTIONENGINE_ORC_OBJECTLINKINGLAYER_H +#endif // LLVM_EXECUTIONENGINE_ORC_RTDYLDOBJECTLINKINGLAYER_H diff --git a/include/llvm/ExecutionEngine/Orc/RawByteChannel.h b/include/llvm/ExecutionEngine/Orc/RawByteChannel.h index 3b6c84eb19650..52a546f7c6eb9 100644 --- a/include/llvm/ExecutionEngine/Orc/RawByteChannel.h +++ b/include/llvm/ExecutionEngine/Orc/RawByteChannel.h @@ -48,7 +48,11 @@ public: template <typename FunctionIdT, typename SequenceIdT> Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { writeLock.lock(); - return serializeSeq(*this, FnId, SeqNo); + if (auto Err = serializeSeq(*this, FnId, SeqNo)) { + writeLock.unlock(); + return Err; + } + return Error::success(); } /// Notify the channel that we're ending a message send. @@ -63,7 +67,11 @@ public: template <typename FunctionIdT, typename SequenceNumberT> Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { readLock.lock(); - return deserializeSeq(*this, FnId, SeqNo); + if (auto Err = deserializeSeq(*this, FnId, SeqNo)) { + readLock.unlock(); + return Err; + } + return Error::success(); } /// Notify the channel that we're ending a message receive. @@ -113,11 +121,19 @@ class SerializationTraits<ChannelT, bool, bool, RawByteChannel, ChannelT>::value>::type> { public: static Error serialize(ChannelT &C, bool V) { - return C.appendBytes(reinterpret_cast<const char *>(&V), 1); + uint8_t Tmp = V ? 1 : 0; + if (auto Err = + C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1)) + return Err; + return Error::success(); } static Error deserialize(ChannelT &C, bool &V) { - return C.readBytes(reinterpret_cast<char *>(&V), 1); + uint8_t Tmp = 0; + if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1)) + return Err; + V = Tmp != 0; + return Error::success(); } }; @@ -134,10 +150,12 @@ public: } }; -template <typename ChannelT> -class SerializationTraits<ChannelT, std::string, const char *, - typename std::enable_if<std::is_base_of< - RawByteChannel, ChannelT>::value>::type> { +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, std::string, T, + typename std::enable_if< + std::is_base_of<RawByteChannel, ChannelT>::value && + (std::is_same<T, const char*>::value || + std::is_same<T, char*>::value)>::type> { public: static Error serialize(RawByteChannel &C, const char *S) { return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, |