diff options
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc/RPCSerialization.h')
| -rw-r--r-- | include/llvm/ExecutionEngine/Orc/RPCSerialization.h | 243 | 
1 files changed, 234 insertions, 9 deletions
| diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h index 359a9d81b22b..84a037b2f998 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h +++ b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -12,6 +12,7 @@  #include "OrcError.h"  #include "llvm/Support/thread.h" +#include <map>  #include <mutex>  #include <sstream> @@ -114,6 +115,35 @@ public:    static const char* getName() { return "std::string"; }  }; +template <> +class RPCTypeName<Error> { +public: +  static const char* getName() { return "Error"; } +}; + +template <typename T> +class RPCTypeName<Expected<T>> { +public: +  static const char* getName() { +    std::lock_guard<std::mutex> Lock(NameMutex); +    if (Name.empty()) +      raw_string_ostream(Name) << "Expected<" +                               << RPCTypeNameSequence<T>() +                               << ">"; +    return Name.data(); +  } + +private: +  static std::mutex NameMutex; +  static std::string Name; +}; + +template <typename T> +std::mutex RPCTypeName<Expected<T>>::NameMutex; + +template <typename T> +std::string RPCTypeName<Expected<T>>::Name; +  template <typename T1, typename T2>  class RPCTypeName<std::pair<T1, T2>> {  public: @@ -243,8 +273,10 @@ class SequenceSerialization<ChannelT, ArgT> {  public:    template <typename CArgT> -  static Error serialize(ChannelT &C, const CArgT &CArg) { -    return SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg); +  static Error serialize(ChannelT &C, CArgT &&CArg) { +    return SerializationTraits<ChannelT, ArgT, +                               typename std::decay<CArgT>::type>:: +             serialize(C, std::forward<CArgT>(CArg));    }    template <typename CArgT> @@ -258,19 +290,21 @@ class SequenceSerialization<ChannelT, ArgT, ArgTs...> {  public:    template <typename CArgT, typename... CArgTs> -  static Error serialize(ChannelT &C, const CArgT &CArg, -                         const CArgTs&... CArgs) { +  static Error serialize(ChannelT &C, CArgT &&CArg, +                         CArgTs &&... CArgs) {      if (auto Err = -        SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg)) +        SerializationTraits<ChannelT, ArgT, typename std::decay<CArgT>::type>:: +          serialize(C, std::forward<CArgT>(CArg)))        return Err;      if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C))        return Err; -    return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); +    return SequenceSerialization<ChannelT, ArgTs...>:: +             serialize(C, std::forward<CArgTs>(CArgs)...);    }    template <typename CArgT, typename... CArgTs>    static Error deserialize(ChannelT &C, CArgT &CArg, -                           CArgTs&... CArgs) { +                           CArgTs &... CArgs) {      if (auto Err =          SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg))        return Err; @@ -281,8 +315,9 @@ public:  };  template <typename ChannelT, typename... ArgTs> -Error serializeSeq(ChannelT &C, const ArgTs &... Args) { -  return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, Args...); +Error serializeSeq(ChannelT &C, ArgTs &&... Args) { +  return SequenceSerialization<ChannelT, typename std::decay<ArgTs>::type...>:: +           serialize(C, std::forward<ArgTs>(Args)...);  }  template <typename ChannelT, typename... ArgTs> @@ -290,6 +325,196 @@ Error deserializeSeq(ChannelT &C, ArgTs &... Args) {    return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...);  } +template <typename ChannelT> +class SerializationTraits<ChannelT, Error> { +public: + +  using WrappedErrorSerializer = +    std::function<Error(ChannelT &C, const ErrorInfoBase&)>; + +  using WrappedErrorDeserializer = +    std::function<Error(ChannelT &C, Error &Err)>; + +  template <typename ErrorInfoT, typename SerializeFtor, +            typename DeserializeFtor> +  static void registerErrorType(std::string Name, SerializeFtor Serialize, +                                DeserializeFtor Deserialize) { +    assert(!Name.empty() && +           "The empty string is reserved for the Success value"); + +    const std::string *KeyName = nullptr; +    { +      // We're abusing the stability of std::map here: We take a reference to the +      // key of the deserializers map to save us from duplicating the string in +      // the serializer. This should be changed to use a stringpool if we switch +      // to a map type that may move keys in memory. +      std::lock_guard<std::mutex> Lock(DeserializersMutex); +      auto I = +        Deserializers.insert(Deserializers.begin(), +                             std::make_pair(std::move(Name), +                                            std::move(Deserialize))); +      KeyName = &I->first; +    } +     +    { +      assert(KeyName != nullptr && "No keyname pointer"); +      std::lock_guard<std::mutex> Lock(SerializersMutex); +      // FIXME: Move capture Serialize once we have C++14. +      Serializers[ErrorInfoT::classID()] = +	[KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error { +          assert(EIB.dynamicClassID() == ErrorInfoT::classID() && +		 "Serializer called for wrong error type"); +	  if (auto Err = serializeSeq(C, *KeyName)) +	    return Err; +	  return Serialize(C, static_cast<const ErrorInfoT&>(EIB)); +        }; +    } +  } +   +  static Error serialize(ChannelT &C, Error &&Err) { +    std::lock_guard<std::mutex> Lock(SerializersMutex); +    if (!Err) +      return serializeSeq(C, std::string()); + +    return handleErrors(std::move(Err), +                        [&C](const ErrorInfoBase &EIB) { +                          auto SI = Serializers.find(EIB.dynamicClassID()); +                          if (SI == Serializers.end()) +                            return serializeAsStringError(C, EIB); +                          return (SI->second)(C, EIB); +                        }); +  } + +  static Error deserialize(ChannelT &C, Error &Err) { +    std::lock_guard<std::mutex> Lock(DeserializersMutex); + +    std::string Key; +    if (auto Err = deserializeSeq(C, Key)) +      return Err; + +    if (Key.empty()) { +      ErrorAsOutParameter EAO(&Err); +      Err = Error::success(); +      return Error::success(); +    } + +    auto DI = Deserializers.find(Key); +    assert(DI != Deserializers.end() && "No deserializer for error type"); +    return (DI->second)(C, Err); +  } + +private: + +  static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { +    assert(EIB.dynamicClassID() != StringError::classID() && +           "StringError serialization not registered"); +    std::string ErrMsg; +    { +      raw_string_ostream ErrMsgStream(ErrMsg); +      EIB.log(ErrMsgStream); +    } +    return serialize(C, make_error<StringError>(std::move(ErrMsg), +                                                inconvertibleErrorCode())); +  } + +  static std::mutex SerializersMutex; +  static std::mutex DeserializersMutex; +  static std::map<const void*, WrappedErrorSerializer> Serializers; +  static std::map<std::string, WrappedErrorDeserializer> Deserializers; +}; + +template <typename ChannelT> +std::mutex SerializationTraits<ChannelT, Error>::SerializersMutex; + +template <typename ChannelT> +std::mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; + +template <typename ChannelT> +std::map<const void*, +         typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer> +SerializationTraits<ChannelT, Error>::Serializers; + +template <typename ChannelT> +std::map<std::string, +         typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer> +SerializationTraits<ChannelT, Error>::Deserializers; + +template <typename ChannelT> +void registerStringError() { +  static bool AlreadyRegistered = false; +  if (!AlreadyRegistered) { +    SerializationTraits<ChannelT, Error>:: +      template registerErrorType<StringError>( +        "StringError", +        [](ChannelT &C, const StringError &SE) { +          return serializeSeq(C, SE.getMessage()); +        }, +        [](ChannelT &C, Error &Err) { +          ErrorAsOutParameter EAO(&Err); +          std::string Msg; +          if (auto E2 = deserializeSeq(C, Msg)) +            return E2; +          Err = +            make_error<StringError>(std::move(Msg), +                                    orcError( +                                      OrcErrorCode::UnknownErrorCodeFromRemote)); +          return Error::success(); +        }); +    AlreadyRegistered = true; +  } +} + +/// SerializationTraits for Expected<T1> from an Expected<T2>. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> { +public: + +  static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) { +    if (ValOrErr) { +      if (auto Err = serializeSeq(C, true)) +        return Err; +      return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr); +    } +    if (auto Err = serializeSeq(C, false)) +      return Err; +    return serializeSeq(C, ValOrErr.takeError()); +  } + +  static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) { +    ExpectedAsOutParameter<T2> EAO(&ValOrErr); +    bool HasValue; +    if (auto Err = deserializeSeq(C, HasValue)) +      return Err; +    if (HasValue) +      return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr); +    Error Err = Error::success(); +    if (auto E2 = deserializeSeq(C, Err)) +      return E2; +    ValOrErr = std::move(Err); +    return Error::success(); +  } +}; + +/// SerializationTraits for Expected<T1> from a T2. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, T2> { +public: + +  static Error serialize(ChannelT &C, T2 &&Val) { +    return serializeSeq(C, Expected<T2>(std::forward<T2>(Val))); +  } +}; + +/// SerializationTraits for Expected<T1> from an Error. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, Expected<T>, Error> { +public: + +  static Error serialize(ChannelT &C, Error &&Err) { +    return serializeSeq(C, Expected<T>(std::move(Err))); +  } +}; +  /// SerializationTraits default specialization for std::pair.  template <typename ChannelT, typename T1, typename T2>  class SerializationTraits<ChannelT, std::pair<T1, T2>> { | 
