diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2017-04-16 16:01:22 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2017-04-16 16:01:22 +0000 |
commit | 71d5a2540a98c81f5bcaeb48805e0e2881f530ef (patch) | |
tree | 5343938942df402b49ec7300a1c25a2d4ccd5821 /include/llvm/ExecutionEngine/Orc/RPCSerialization.h | |
parent | 31bbf64f3a4974a2d6c8b3b27ad2f519caf74057 (diff) |
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc/RPCSerialization.h')
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/RPCSerialization.h | 243 |
1 files changed, 234 insertions, 9 deletions
diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h index 359a9d81b22b..84a037b2f998 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h +++ b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -12,6 +12,7 @@ #include "OrcError.h" #include "llvm/Support/thread.h" +#include <map> #include <mutex> #include <sstream> @@ -114,6 +115,35 @@ public: static const char* getName() { return "std::string"; } }; +template <> +class RPCTypeName<Error> { +public: + static const char* getName() { return "Error"; } +}; + +template <typename T> +class RPCTypeName<Expected<T>> { +public: + static const char* getName() { + std::lock_guard<std::mutex> Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) << "Expected<" + << RPCTypeNameSequence<T>() + << ">"; + return Name.data(); + } + +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template <typename T> +std::mutex RPCTypeName<Expected<T>>::NameMutex; + +template <typename T> +std::string RPCTypeName<Expected<T>>::Name; + template <typename T1, typename T2> class RPCTypeName<std::pair<T1, T2>> { public: @@ -243,8 +273,10 @@ class SequenceSerialization<ChannelT, ArgT> { public: template <typename CArgT> - static Error serialize(ChannelT &C, const CArgT &CArg) { - return SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg); + static Error serialize(ChannelT &C, CArgT &&CArg) { + return SerializationTraits<ChannelT, ArgT, + typename std::decay<CArgT>::type>:: + serialize(C, std::forward<CArgT>(CArg)); } template <typename CArgT> @@ -258,19 +290,21 @@ class SequenceSerialization<ChannelT, ArgT, ArgTs...> { public: template <typename CArgT, typename... CArgTs> - static Error serialize(ChannelT &C, const CArgT &CArg, - const CArgTs&... CArgs) { + static Error serialize(ChannelT &C, CArgT &&CArg, + CArgTs &&... CArgs) { if (auto Err = - SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg)) + SerializationTraits<ChannelT, ArgT, typename std::decay<CArgT>::type>:: + serialize(C, std::forward<CArgT>(CArg))) return Err; if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C)) return Err; - return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); + return SequenceSerialization<ChannelT, ArgTs...>:: + serialize(C, std::forward<CArgTs>(CArgs)...); } template <typename CArgT, typename... CArgTs> static Error deserialize(ChannelT &C, CArgT &CArg, - CArgTs&... CArgs) { + CArgTs &... CArgs) { if (auto Err = SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg)) return Err; @@ -281,8 +315,9 @@ public: }; template <typename ChannelT, typename... ArgTs> -Error serializeSeq(ChannelT &C, const ArgTs &... Args) { - return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, Args...); +Error serializeSeq(ChannelT &C, ArgTs &&... Args) { + return SequenceSerialization<ChannelT, typename std::decay<ArgTs>::type...>:: + serialize(C, std::forward<ArgTs>(Args)...); } template <typename ChannelT, typename... ArgTs> @@ -290,6 +325,196 @@ Error deserializeSeq(ChannelT &C, ArgTs &... Args) { return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...); } +template <typename ChannelT> +class SerializationTraits<ChannelT, Error> { +public: + + using WrappedErrorSerializer = + std::function<Error(ChannelT &C, const ErrorInfoBase&)>; + + using WrappedErrorDeserializer = + std::function<Error(ChannelT &C, Error &Err)>; + + template <typename ErrorInfoT, typename SerializeFtor, + typename DeserializeFtor> + static void registerErrorType(std::string Name, SerializeFtor Serialize, + DeserializeFtor Deserialize) { + assert(!Name.empty() && + "The empty string is reserved for the Success value"); + + const std::string *KeyName = nullptr; + { + // We're abusing the stability of std::map here: We take a reference to the + // key of the deserializers map to save us from duplicating the string in + // the serializer. This should be changed to use a stringpool if we switch + // to a map type that may move keys in memory. + std::lock_guard<std::mutex> Lock(DeserializersMutex); + auto I = + Deserializers.insert(Deserializers.begin(), + std::make_pair(std::move(Name), + std::move(Deserialize))); + KeyName = &I->first; + } + + { + assert(KeyName != nullptr && "No keyname pointer"); + std::lock_guard<std::mutex> Lock(SerializersMutex); + // FIXME: Move capture Serialize once we have C++14. + Serializers[ErrorInfoT::classID()] = + [KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error { + assert(EIB.dynamicClassID() == ErrorInfoT::classID() && + "Serializer called for wrong error type"); + if (auto Err = serializeSeq(C, *KeyName)) + return Err; + return Serialize(C, static_cast<const ErrorInfoT&>(EIB)); + }; + } + } + + static Error serialize(ChannelT &C, Error &&Err) { + std::lock_guard<std::mutex> Lock(SerializersMutex); + if (!Err) + return serializeSeq(C, std::string()); + + return handleErrors(std::move(Err), + [&C](const ErrorInfoBase &EIB) { + auto SI = Serializers.find(EIB.dynamicClassID()); + if (SI == Serializers.end()) + return serializeAsStringError(C, EIB); + return (SI->second)(C, EIB); + }); + } + + static Error deserialize(ChannelT &C, Error &Err) { + std::lock_guard<std::mutex> Lock(DeserializersMutex); + + std::string Key; + if (auto Err = deserializeSeq(C, Key)) + return Err; + + if (Key.empty()) { + ErrorAsOutParameter EAO(&Err); + Err = Error::success(); + return Error::success(); + } + + auto DI = Deserializers.find(Key); + assert(DI != Deserializers.end() && "No deserializer for error type"); + return (DI->second)(C, Err); + } + +private: + + static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { + assert(EIB.dynamicClassID() != StringError::classID() && + "StringError serialization not registered"); + std::string ErrMsg; + { + raw_string_ostream ErrMsgStream(ErrMsg); + EIB.log(ErrMsgStream); + } + return serialize(C, make_error<StringError>(std::move(ErrMsg), + inconvertibleErrorCode())); + } + + static std::mutex SerializersMutex; + static std::mutex DeserializersMutex; + static std::map<const void*, WrappedErrorSerializer> Serializers; + static std::map<std::string, WrappedErrorDeserializer> Deserializers; +}; + +template <typename ChannelT> +std::mutex SerializationTraits<ChannelT, Error>::SerializersMutex; + +template <typename ChannelT> +std::mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; + +template <typename ChannelT> +std::map<const void*, + typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer> +SerializationTraits<ChannelT, Error>::Serializers; + +template <typename ChannelT> +std::map<std::string, + typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer> +SerializationTraits<ChannelT, Error>::Deserializers; + +template <typename ChannelT> +void registerStringError() { + static bool AlreadyRegistered = false; + if (!AlreadyRegistered) { + SerializationTraits<ChannelT, Error>:: + template registerErrorType<StringError>( + "StringError", + [](ChannelT &C, const StringError &SE) { + return serializeSeq(C, SE.getMessage()); + }, + [](ChannelT &C, Error &Err) { + ErrorAsOutParameter EAO(&Err); + std::string Msg; + if (auto E2 = deserializeSeq(C, Msg)) + return E2; + Err = + make_error<StringError>(std::move(Msg), + orcError( + OrcErrorCode::UnknownErrorCodeFromRemote)); + return Error::success(); + }); + AlreadyRegistered = true; + } +} + +/// SerializationTraits for Expected<T1> from an Expected<T2>. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> { +public: + + static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) { + if (ValOrErr) { + if (auto Err = serializeSeq(C, true)) + return Err; + return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr); + } + if (auto Err = serializeSeq(C, false)) + return Err; + return serializeSeq(C, ValOrErr.takeError()); + } + + static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) { + ExpectedAsOutParameter<T2> EAO(&ValOrErr); + bool HasValue; + if (auto Err = deserializeSeq(C, HasValue)) + return Err; + if (HasValue) + return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr); + Error Err = Error::success(); + if (auto E2 = deserializeSeq(C, Err)) + return E2; + ValOrErr = std::move(Err); + return Error::success(); + } +}; + +/// SerializationTraits for Expected<T1> from a T2. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, T2> { +public: + + static Error serialize(ChannelT &C, T2 &&Val) { + return serializeSeq(C, Expected<T2>(std::forward<T2>(Val))); + } +}; + +/// SerializationTraits for Expected<T1> from an Error. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, Expected<T>, Error> { +public: + + static Error serialize(ChannelT &C, Error &&Err) { + return serializeSeq(C, Expected<T>(std::move(Err))); + } +}; + /// SerializationTraits default specialization for std::pair. template <typename ChannelT, typename T1, typename T2> class SerializationTraits<ChannelT, std::pair<T1, T2>> { |