diff options
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc/RPCSerialization.h')
-rw-r--r-- | include/llvm/ExecutionEngine/Orc/RPCSerialization.h | 65 |
1 files changed, 38 insertions, 27 deletions
diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h index 84a037b2f998b..a3be242b44577 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h +++ b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -348,7 +348,7 @@ public: // key of the deserializers map to save us from duplicating the string in // the serializer. This should be changed to use a stringpool if we switch // to a map type that may move keys in memory. - std::lock_guard<std::mutex> Lock(DeserializersMutex); + std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); auto I = Deserializers.insert(Deserializers.begin(), std::make_pair(std::move(Name), @@ -358,7 +358,7 @@ public: { assert(KeyName != nullptr && "No keyname pointer"); - std::lock_guard<std::mutex> Lock(SerializersMutex); + std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); // FIXME: Move capture Serialize once we have C++14. Serializers[ErrorInfoT::classID()] = [KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error { @@ -372,7 +372,8 @@ public: } static Error serialize(ChannelT &C, Error &&Err) { - std::lock_guard<std::mutex> Lock(SerializersMutex); + std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); + if (!Err) return serializeSeq(C, std::string()); @@ -386,7 +387,7 @@ public: } static Error deserialize(ChannelT &C, Error &Err) { - std::lock_guard<std::mutex> Lock(DeserializersMutex); + std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); std::string Key; if (auto Err = deserializeSeq(C, Key)) @@ -406,8 +407,6 @@ public: private: static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { - assert(EIB.dynamicClassID() != StringError::classID() && - "StringError serialization not registered"); std::string ErrMsg; { raw_string_ostream ErrMsgStream(ErrMsg); @@ -417,17 +416,17 @@ private: inconvertibleErrorCode())); } - static std::mutex SerializersMutex; - static std::mutex DeserializersMutex; + static std::recursive_mutex SerializersMutex; + static std::recursive_mutex DeserializersMutex; static std::map<const void*, WrappedErrorSerializer> Serializers; static std::map<std::string, WrappedErrorDeserializer> Deserializers; }; template <typename ChannelT> -std::mutex SerializationTraits<ChannelT, Error>::SerializersMutex; +std::recursive_mutex SerializationTraits<ChannelT, Error>::SerializersMutex; template <typename ChannelT> -std::mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; +std::recursive_mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; template <typename ChannelT> std::map<const void*, @@ -439,27 +438,39 @@ std::map<std::string, typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer> SerializationTraits<ChannelT, Error>::Deserializers; +/// Registers a serializer and deserializer for the given error type on the +/// given channel type. +template <typename ChannelT, typename ErrorInfoT, typename SerializeFtor, + typename DeserializeFtor> +void registerErrorSerialization(std::string Name, SerializeFtor &&Serialize, + DeserializeFtor &&Deserialize) { + SerializationTraits<ChannelT, Error>::template registerErrorType<ErrorInfoT>( + std::move(Name), + std::forward<SerializeFtor>(Serialize), + std::forward<DeserializeFtor>(Deserialize)); +} + +/// Registers serialization/deserialization for StringError. template <typename ChannelT> void registerStringError() { static bool AlreadyRegistered = false; if (!AlreadyRegistered) { - SerializationTraits<ChannelT, Error>:: - template registerErrorType<StringError>( - "StringError", - [](ChannelT &C, const StringError &SE) { - return serializeSeq(C, SE.getMessage()); - }, - [](ChannelT &C, Error &Err) { - ErrorAsOutParameter EAO(&Err); - std::string Msg; - if (auto E2 = deserializeSeq(C, Msg)) - return E2; - Err = - make_error<StringError>(std::move(Msg), - orcError( - OrcErrorCode::UnknownErrorCodeFromRemote)); - return Error::success(); - }); + registerErrorSerialization<ChannelT, StringError>( + "StringError", + [](ChannelT &C, const StringError &SE) { + return serializeSeq(C, SE.getMessage()); + }, + [](ChannelT &C, Error &Err) -> Error { + ErrorAsOutParameter EAO(&Err); + std::string Msg; + if (auto E2 = deserializeSeq(C, Msg)) + return E2; + Err = + make_error<StringError>(std::move(Msg), + orcError( + OrcErrorCode::UnknownErrorCodeFromRemote)); + return Error::success(); + }); AlreadyRegistered = true; } } |