diff options
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc')
| -rw-r--r-- | include/llvm/ExecutionEngine/Orc/RPCSerialization.h | 145 | 
1 files changed, 139 insertions, 6 deletions
| diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h index 569c50602f3a..1e5f6ced597a 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h +++ b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -14,7 +14,10 @@  #include "llvm/Support/thread.h"  #include <map>  #include <mutex> +#include <set>  #include <sstream> +#include <string> +#include <vector>  namespace llvm {  namespace orc { @@ -205,6 +208,42 @@ std::mutex RPCTypeName<std::vector<T>>::NameMutex;  template <typename T>  std::string RPCTypeName<std::vector<T>>::Name; +template <typename T> class RPCTypeName<std::set<T>> { +public: +  static const char *getName() { +    std::lock_guard<std::mutex> Lock(NameMutex); +    if (Name.empty()) +      raw_string_ostream(Name) +          << "std::set<" << RPCTypeName<T>::getName() << ">"; +    return Name.data(); +  } + +private: +  static std::mutex NameMutex; +  static std::string Name; +}; + +template <typename T> std::mutex RPCTypeName<std::set<T>>::NameMutex; +template <typename T> std::string RPCTypeName<std::set<T>>::Name; + +template <typename K, typename V> class RPCTypeName<std::map<K, V>> { +public: +  static const char *getName() { +    std::lock_guard<std::mutex> Lock(NameMutex); +    if (Name.empty()) +      raw_string_ostream(Name) +          << "std::map<" << RPCTypeNameSequence<K, V>() << ">"; +    return Name.data(); +  } + +private: +  static std::mutex NameMutex; +  static std::string Name; +}; + +template <typename K, typename V> +std::mutex RPCTypeName<std::map<K, V>>::NameMutex; +template <typename K, typename V> std::string RPCTypeName<std::map<K, V>>::Name;  /// The SerializationTraits<ChannelT, T> class describes how to serialize and  /// deserialize an instance of type T to/from an abstract channel of type @@ -527,15 +566,20 @@ public:  };  /// SerializationTraits default specialization for std::pair. -template <typename ChannelT, typename T1, typename T2> -class SerializationTraits<ChannelT, std::pair<T1, T2>> { +template <typename ChannelT, typename T1, typename T2, typename T3, typename T4> +class SerializationTraits<ChannelT, std::pair<T1, T2>, std::pair<T3, T4>> {  public: -  static Error serialize(ChannelT &C, const std::pair<T1, T2> &V) { -    return serializeSeq(C, V.first, V.second); +  static Error serialize(ChannelT &C, const std::pair<T3, T4> &V) { +    if (auto Err = SerializationTraits<ChannelT, T1, T3>::serialize(C, V.first)) +      return Err; +    return SerializationTraits<ChannelT, T2, T4>::serialize(C, V.second);    } -  static Error deserialize(ChannelT &C, std::pair<T1, T2> &V) { -    return deserializeSeq(C, V.first, V.second); +  static Error deserialize(ChannelT &C, std::pair<T3, T4> &V) { +    if (auto Err = +            SerializationTraits<ChannelT, T1, T3>::deserialize(C, V.first)) +      return Err; +    return SerializationTraits<ChannelT, T2, T4>::deserialize(C, V.second);    }  }; @@ -589,6 +633,9 @@ public:    /// Deserialize a std::vector<T> to a std::vector<T>.    static Error deserialize(ChannelT &C, std::vector<T> &V) { +    assert(V.empty() && +           "Expected default-constructed vector to deserialize into"); +      uint64_t Count = 0;      if (auto Err = deserializeSeq(C, Count))        return Err; @@ -602,6 +649,92 @@ public:    }  }; +template <typename ChannelT, typename T, typename T2> +class SerializationTraits<ChannelT, std::set<T>, std::set<T2>> { +public: +  /// Serialize a std::set<T> from std::set<T2>. +  static Error serialize(ChannelT &C, const std::set<T2> &S) { +    if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) +      return Err; + +    for (const auto &E : S) +      if (auto Err = SerializationTraits<ChannelT, T, T2>::serialize(C, E)) +        return Err; + +    return Error::success(); +  } + +  /// Deserialize a std::set<T> to a std::set<T>. +  static Error deserialize(ChannelT &C, std::set<T2> &S) { +    assert(S.empty() && "Expected default-constructed set to deserialize into"); + +    uint64_t Count = 0; +    if (auto Err = deserializeSeq(C, Count)) +      return Err; + +    while (Count-- != 0) { +      T2 Val; +      if (auto Err = SerializationTraits<ChannelT, T, T2>::deserialize(C, Val)) +        return Err; + +      auto Added = S.insert(Val).second; +      if (!Added) +        return make_error<StringError>("Duplicate element in deserialized set", +                                       orcError(OrcErrorCode::UnknownORCError)); +    } + +    return Error::success(); +  } +}; + +template <typename ChannelT, typename K, typename V, typename K2, typename V2> +class SerializationTraits<ChannelT, std::map<K, V>, std::map<K2, V2>> { +public: +  /// Serialize a std::map<K, V> from std::map<K2, V2>. +  static Error serialize(ChannelT &C, const std::map<K2, V2> &M) { +    if (auto Err = serializeSeq(C, static_cast<uint64_t>(M.size()))) +      return Err; + +    for (const auto &E : M) { +      if (auto Err = +              SerializationTraits<ChannelT, K, K2>::serialize(C, E.first)) +        return Err; +      if (auto Err = +              SerializationTraits<ChannelT, V, V2>::serialize(C, E.second)) +        return Err; +    } + +    return Error::success(); +  } + +  /// Deserialize a std::map<K, V> to a std::map<K, V>. +  static Error deserialize(ChannelT &C, std::map<K2, V2> &M) { +    assert(M.empty() && "Expected default-constructed map to deserialize into"); + +    uint64_t Count = 0; +    if (auto Err = deserializeSeq(C, Count)) +      return Err; + +    while (Count-- != 0) { +      std::pair<K2, V2> Val; +      if (auto Err = +              SerializationTraits<ChannelT, K, K2>::deserialize(C, Val.first)) +        return Err; + +      if (auto Err = +              SerializationTraits<ChannelT, V, V2>::deserialize(C, Val.second)) +        return Err; + +      auto Added = M.insert(Val).second; +      if (!Added) +        return make_error<StringError>("Duplicate element in deserialized map", +                                       orcError(OrcErrorCode::UnknownORCError)); +    } + +    return Error::success(); +  } +}; +  } // end namespace rpc  } // end namespace orc  } // end namespace llvm | 
