diff options
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc')
| -rw-r--r-- | include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h | 12 | ||||
| -rw-r--r-- | include/llvm/ExecutionEngine/Orc/OrcError.h | 6 | ||||
| -rw-r--r-- | include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h | 1 | ||||
| -rw-r--r-- | include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h | 23 | ||||
| -rw-r--r-- | include/llvm/ExecutionEngine/Orc/RPCSerialization.h | 243 | ||||
| -rw-r--r-- | include/llvm/ExecutionEngine/Orc/RPCUtils.h | 787 | ||||
| -rw-r--r-- | include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h (renamed from include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h) | 18 | ||||
| -rw-r--r-- | include/llvm/ExecutionEngine/Orc/RawByteChannel.h | 34 | 
8 files changed, 833 insertions, 291 deletions
| diff --git a/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h b/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h index aa096478cd9e..7e7f7358938a 100644 --- a/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h +++ b/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h @@ -376,7 +376,7 @@ private:      // Initializers may refer to functions declared (but not defined) in this      // module. Build a materializer to clone decls on demand.      auto Materializer = createLambdaMaterializer( -      [this, &LD, &GVsM](Value *V) -> Value* { +      [&LD, &GVsM](Value *V) -> Value* {          if (auto *F = dyn_cast<Function>(V)) {            // Decls in the original module just get cloned.            if (F->isDeclaration()) @@ -419,7 +419,7 @@ private:      // Build a resolver for the globals module and add it to the base layer.      auto GVsResolver = createLambdaResolver( -        [this, &LD, LMId](const std::string &Name) { +        [this, &LD](const std::string &Name) {            if (auto Sym = LD.StubsMgr->findStub(Name, false))              return Sym;            if (auto Sym = LD.findSymbol(BaseLayer, Name, false)) @@ -499,8 +499,8 @@ private:      M->setDataLayout(SrcM.getDataLayout());      ValueToValueMapTy VMap; -    auto Materializer = createLambdaMaterializer([this, &LD, &LMId, &M, -                                                  &VMap](Value *V) -> Value * { +    auto Materializer = createLambdaMaterializer([&LD, &LMId, +                                                  &M](Value *V) -> Value * {        if (auto *GV = dyn_cast<GlobalVariable>(V))          return cloneGlobalVariableDecl(*M, *GV); @@ -546,12 +546,12 @@ private:      // Create memory manager and symbol resolver.      auto Resolver = createLambdaResolver( -        [this, &LD, LMId](const std::string &Name) { +        [this, &LD](const std::string &Name) {            if (auto Sym = LD.findSymbol(BaseLayer, Name, false))              return Sym;            return LD.ExternalSymbolResolver->findSymbolInLogicalDylib(Name);          }, -        [this, &LD](const std::string &Name) { +        [&LD](const std::string &Name) {            return LD.ExternalSymbolResolver->findSymbol(Name);          }); diff --git a/include/llvm/ExecutionEngine/Orc/OrcError.h b/include/llvm/ExecutionEngine/Orc/OrcError.h index b74988cce2fb..cbb40fad0223 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcError.h +++ b/include/llvm/ExecutionEngine/Orc/OrcError.h @@ -27,13 +27,15 @@ enum class OrcErrorCode : int {    RemoteMProtectAddrUnrecognized,    RemoteIndirectStubsOwnerDoesNotExist,    RemoteIndirectStubsOwnerIdAlreadyInUse, +  RPCConnectionClosed, +  RPCCouldNotNegotiateFunction,    RPCResponseAbandoned,    UnexpectedRPCCall,    UnexpectedRPCResponse, -  UnknownRPCFunction +  UnknownErrorCodeFromRemote  }; -Error orcError(OrcErrorCode ErrCode); +std::error_code orcError(OrcErrorCode ErrCode);  } // End namespace orc.  } // End namespace llvm. diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h index 8647db56cd2f..02f59d6a831a 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h @@ -18,6 +18,7 @@  #include "IndirectionUtils.h"  #include "OrcRemoteTargetRPCAPI.h" +#include "llvm/ExecutionEngine/RuntimeDyld.h"  #include <system_error>  #define DEBUG_TYPE "orc-remote" diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h index 506330fe3a5e..a61ff102be0b 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h @@ -132,7 +132,7 @@ private:      Error setProtections(void *block, unsigned Flags) {        auto I = Allocs.find(block);        if (I == Allocs.end()) -        return orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized); +        return errorCodeToError(orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized));        return errorCodeToError(            sys::Memory::protectMappedMemory(I->second, Flags));      } @@ -198,7 +198,8 @@ private:    Error handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) {      auto I = Allocators.find(Id);      if (I != Allocators.end()) -      return orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse); +      return errorCodeToError( +               orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse));      DEBUG(dbgs() << "  Created allocator " << Id << "\n");      Allocators[Id] = Allocator();      return Error::success(); @@ -207,7 +208,8 @@ private:    Error handleCreateIndirectStubsOwner(ResourceIdMgr::ResourceId Id) {      auto I = IndirectStubsOwners.find(Id);      if (I != IndirectStubsOwners.end()) -      return orcError(OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse); +      return errorCodeToError( +               orcError(OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse));      DEBUG(dbgs() << "  Create indirect stubs owner " << Id << "\n");      IndirectStubsOwners[Id] = ISBlockOwnerList();      return Error::success(); @@ -224,7 +226,8 @@ private:    Error handleDestroyRemoteAllocator(ResourceIdMgr::ResourceId Id) {      auto I = Allocators.find(Id);      if (I == Allocators.end()) -      return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); +      return errorCodeToError( +               orcError(OrcErrorCode::RemoteAllocatorDoesNotExist));      Allocators.erase(I);      DEBUG(dbgs() << "  Destroyed allocator " << Id << "\n");      return Error::success(); @@ -233,7 +236,8 @@ private:    Error handleDestroyIndirectStubsOwner(ResourceIdMgr::ResourceId Id) {      auto I = IndirectStubsOwners.find(Id);      if (I == IndirectStubsOwners.end()) -      return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist); +      return errorCodeToError( +               orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist));      IndirectStubsOwners.erase(I);      return Error::success();    } @@ -246,7 +250,8 @@ private:      auto StubOwnerItr = IndirectStubsOwners.find(Id);      if (StubOwnerItr == IndirectStubsOwners.end()) -      return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist); +      return errorCodeToError( +               orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist));      typename TargetT::IndirectStubsInfo IS;      if (auto Err = @@ -361,7 +366,8 @@ private:                                                uint64_t Size, uint32_t Align) {      auto I = Allocators.find(Id);      if (I == Allocators.end()) -      return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); +      return errorCodeToError( +               orcError(OrcErrorCode::RemoteAllocatorDoesNotExist));      auto &Allocator = I->second;      void *LocalAllocAddr = nullptr;      if (auto Err = Allocator.allocate(LocalAllocAddr, Size, Align)) @@ -380,7 +386,8 @@ private:                               JITTargetAddress Addr, uint32_t Flags) {      auto I = Allocators.find(Id);      if (I == Allocators.end()) -      return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); +      return errorCodeToError( +               orcError(OrcErrorCode::RemoteAllocatorDoesNotExist));      auto &Allocator = I->second;      void *LocalAddr = reinterpret_cast<void *>(static_cast<uintptr_t>(Addr));      DEBUG(dbgs() << "  Allocator " << Id << " set permissions on " << LocalAddr diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h index 359a9d81b22b..84a037b2f998 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h +++ b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -12,6 +12,7 @@  #include "OrcError.h"  #include "llvm/Support/thread.h" +#include <map>  #include <mutex>  #include <sstream> @@ -114,6 +115,35 @@ public:    static const char* getName() { return "std::string"; }  }; +template <> +class RPCTypeName<Error> { +public: +  static const char* getName() { return "Error"; } +}; + +template <typename T> +class RPCTypeName<Expected<T>> { +public: +  static const char* getName() { +    std::lock_guard<std::mutex> Lock(NameMutex); +    if (Name.empty()) +      raw_string_ostream(Name) << "Expected<" +                               << RPCTypeNameSequence<T>() +                               << ">"; +    return Name.data(); +  } + +private: +  static std::mutex NameMutex; +  static std::string Name; +}; + +template <typename T> +std::mutex RPCTypeName<Expected<T>>::NameMutex; + +template <typename T> +std::string RPCTypeName<Expected<T>>::Name; +  template <typename T1, typename T2>  class RPCTypeName<std::pair<T1, T2>> {  public: @@ -243,8 +273,10 @@ class SequenceSerialization<ChannelT, ArgT> {  public:    template <typename CArgT> -  static Error serialize(ChannelT &C, const CArgT &CArg) { -    return SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg); +  static Error serialize(ChannelT &C, CArgT &&CArg) { +    return SerializationTraits<ChannelT, ArgT, +                               typename std::decay<CArgT>::type>:: +             serialize(C, std::forward<CArgT>(CArg));    }    template <typename CArgT> @@ -258,19 +290,21 @@ class SequenceSerialization<ChannelT, ArgT, ArgTs...> {  public:    template <typename CArgT, typename... CArgTs> -  static Error serialize(ChannelT &C, const CArgT &CArg, -                         const CArgTs&... CArgs) { +  static Error serialize(ChannelT &C, CArgT &&CArg, +                         CArgTs &&... CArgs) {      if (auto Err = -        SerializationTraits<ChannelT, ArgT, CArgT>::serialize(C, CArg)) +        SerializationTraits<ChannelT, ArgT, typename std::decay<CArgT>::type>:: +          serialize(C, std::forward<CArgT>(CArg)))        return Err;      if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C))        return Err; -    return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, CArgs...); +    return SequenceSerialization<ChannelT, ArgTs...>:: +             serialize(C, std::forward<CArgTs>(CArgs)...);    }    template <typename CArgT, typename... CArgTs>    static Error deserialize(ChannelT &C, CArgT &CArg, -                           CArgTs&... CArgs) { +                           CArgTs &... CArgs) {      if (auto Err =          SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg))        return Err; @@ -281,8 +315,9 @@ public:  };  template <typename ChannelT, typename... ArgTs> -Error serializeSeq(ChannelT &C, const ArgTs &... Args) { -  return SequenceSerialization<ChannelT, ArgTs...>::serialize(C, Args...); +Error serializeSeq(ChannelT &C, ArgTs &&... Args) { +  return SequenceSerialization<ChannelT, typename std::decay<ArgTs>::type...>:: +           serialize(C, std::forward<ArgTs>(Args)...);  }  template <typename ChannelT, typename... ArgTs> @@ -290,6 +325,196 @@ Error deserializeSeq(ChannelT &C, ArgTs &... Args) {    return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...);  } +template <typename ChannelT> +class SerializationTraits<ChannelT, Error> { +public: + +  using WrappedErrorSerializer = +    std::function<Error(ChannelT &C, const ErrorInfoBase&)>; + +  using WrappedErrorDeserializer = +    std::function<Error(ChannelT &C, Error &Err)>; + +  template <typename ErrorInfoT, typename SerializeFtor, +            typename DeserializeFtor> +  static void registerErrorType(std::string Name, SerializeFtor Serialize, +                                DeserializeFtor Deserialize) { +    assert(!Name.empty() && +           "The empty string is reserved for the Success value"); + +    const std::string *KeyName = nullptr; +    { +      // We're abusing the stability of std::map here: We take a reference to the +      // key of the deserializers map to save us from duplicating the string in +      // the serializer. This should be changed to use a stringpool if we switch +      // to a map type that may move keys in memory. +      std::lock_guard<std::mutex> Lock(DeserializersMutex); +      auto I = +        Deserializers.insert(Deserializers.begin(), +                             std::make_pair(std::move(Name), +                                            std::move(Deserialize))); +      KeyName = &I->first; +    } +     +    { +      assert(KeyName != nullptr && "No keyname pointer"); +      std::lock_guard<std::mutex> Lock(SerializersMutex); +      // FIXME: Move capture Serialize once we have C++14. +      Serializers[ErrorInfoT::classID()] = +	[KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error { +          assert(EIB.dynamicClassID() == ErrorInfoT::classID() && +		 "Serializer called for wrong error type"); +	  if (auto Err = serializeSeq(C, *KeyName)) +	    return Err; +	  return Serialize(C, static_cast<const ErrorInfoT&>(EIB)); +        }; +    } +  } +   +  static Error serialize(ChannelT &C, Error &&Err) { +    std::lock_guard<std::mutex> Lock(SerializersMutex); +    if (!Err) +      return serializeSeq(C, std::string()); + +    return handleErrors(std::move(Err), +                        [&C](const ErrorInfoBase &EIB) { +                          auto SI = Serializers.find(EIB.dynamicClassID()); +                          if (SI == Serializers.end()) +                            return serializeAsStringError(C, EIB); +                          return (SI->second)(C, EIB); +                        }); +  } + +  static Error deserialize(ChannelT &C, Error &Err) { +    std::lock_guard<std::mutex> Lock(DeserializersMutex); + +    std::string Key; +    if (auto Err = deserializeSeq(C, Key)) +      return Err; + +    if (Key.empty()) { +      ErrorAsOutParameter EAO(&Err); +      Err = Error::success(); +      return Error::success(); +    } + +    auto DI = Deserializers.find(Key); +    assert(DI != Deserializers.end() && "No deserializer for error type"); +    return (DI->second)(C, Err); +  } + +private: + +  static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { +    assert(EIB.dynamicClassID() != StringError::classID() && +           "StringError serialization not registered"); +    std::string ErrMsg; +    { +      raw_string_ostream ErrMsgStream(ErrMsg); +      EIB.log(ErrMsgStream); +    } +    return serialize(C, make_error<StringError>(std::move(ErrMsg), +                                                inconvertibleErrorCode())); +  } + +  static std::mutex SerializersMutex; +  static std::mutex DeserializersMutex; +  static std::map<const void*, WrappedErrorSerializer> Serializers; +  static std::map<std::string, WrappedErrorDeserializer> Deserializers; +}; + +template <typename ChannelT> +std::mutex SerializationTraits<ChannelT, Error>::SerializersMutex; + +template <typename ChannelT> +std::mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; + +template <typename ChannelT> +std::map<const void*, +         typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer> +SerializationTraits<ChannelT, Error>::Serializers; + +template <typename ChannelT> +std::map<std::string, +         typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer> +SerializationTraits<ChannelT, Error>::Deserializers; + +template <typename ChannelT> +void registerStringError() { +  static bool AlreadyRegistered = false; +  if (!AlreadyRegistered) { +    SerializationTraits<ChannelT, Error>:: +      template registerErrorType<StringError>( +        "StringError", +        [](ChannelT &C, const StringError &SE) { +          return serializeSeq(C, SE.getMessage()); +        }, +        [](ChannelT &C, Error &Err) { +          ErrorAsOutParameter EAO(&Err); +          std::string Msg; +          if (auto E2 = deserializeSeq(C, Msg)) +            return E2; +          Err = +            make_error<StringError>(std::move(Msg), +                                    orcError( +                                      OrcErrorCode::UnknownErrorCodeFromRemote)); +          return Error::success(); +        }); +    AlreadyRegistered = true; +  } +} + +/// SerializationTraits for Expected<T1> from an Expected<T2>. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> { +public: + +  static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) { +    if (ValOrErr) { +      if (auto Err = serializeSeq(C, true)) +        return Err; +      return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr); +    } +    if (auto Err = serializeSeq(C, false)) +      return Err; +    return serializeSeq(C, ValOrErr.takeError()); +  } + +  static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) { +    ExpectedAsOutParameter<T2> EAO(&ValOrErr); +    bool HasValue; +    if (auto Err = deserializeSeq(C, HasValue)) +      return Err; +    if (HasValue) +      return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr); +    Error Err = Error::success(); +    if (auto E2 = deserializeSeq(C, Err)) +      return E2; +    ValOrErr = std::move(Err); +    return Error::success(); +  } +}; + +/// SerializationTraits for Expected<T1> from a T2. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, T2> { +public: + +  static Error serialize(ChannelT &C, T2 &&Val) { +    return serializeSeq(C, Expected<T2>(std::forward<T2>(Val))); +  } +}; + +/// SerializationTraits for Expected<T1> from an Error. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, Expected<T>, Error> { +public: + +  static Error serialize(ChannelT &C, Error &&Err) { +    return serializeSeq(C, Expected<T>(std::move(Err))); +  } +}; +  /// SerializationTraits default specialization for std::pair.  template <typename ChannelT, typename T1, typename T2>  class SerializationTraits<ChannelT, std::pair<T1, T2>> { diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 37e2e66e5af4..6212f64ff319 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -26,27 +26,115 @@  #include "llvm/ExecutionEngine/Orc/OrcError.h"  #include "llvm/ExecutionEngine/Orc/RPCSerialization.h" -#ifdef _MSC_VER -// concrt.h depends on eh.h for __uncaught_exception declaration -// even if we disable exceptions. -#include <eh.h> - -// Disable warnings from ppltasks.h transitively included by <future>. -#pragma warning(push) -#pragma warning(disable : 4530) -#pragma warning(disable : 4062) -#endif -  #include <future> -#ifdef _MSC_VER -#pragma warning(pop) -#endif -  namespace llvm {  namespace orc {  namespace rpc { +/// Base class of all fatal RPC errors (those that necessarily result in the +/// termination of the RPC session). +class RPCFatalError : public ErrorInfo<RPCFatalError> { +public: +  static char ID; +}; + +/// RPCConnectionClosed is returned from RPC operations if the RPC connection +/// has already been closed due to either an error or graceful disconnection. +class ConnectionClosed : public ErrorInfo<ConnectionClosed> { +public: +  static char ID; +  std::error_code convertToErrorCode() const override; +  void log(raw_ostream &OS) const override; +}; + +/// BadFunctionCall is returned from handleOne when the remote makes a call with +/// an unrecognized function id. +/// +/// This error is fatal because Orc RPC needs to know how to parse a function +/// call to know where the next call starts, and if it doesn't recognize the +/// function id it cannot parse the call. +template <typename FnIdT, typename SeqNoT> +class BadFunctionCall +  : public ErrorInfo<BadFunctionCall<FnIdT, SeqNoT>, RPCFatalError> { +public: +  static char ID; + +  BadFunctionCall(FnIdT FnId, SeqNoT SeqNo) +      : FnId(std::move(FnId)), SeqNo(std::move(SeqNo)) {} + +  std::error_code convertToErrorCode() const override { +    return orcError(OrcErrorCode::UnexpectedRPCCall); +  } + +  void log(raw_ostream &OS) const override { +    OS << "Call to invalid RPC function id '" << FnId << "' with " +          "sequence number " << SeqNo; +  } + +private: +  FnIdT FnId; +  SeqNoT SeqNo; +}; + +template <typename FnIdT, typename SeqNoT> +char BadFunctionCall<FnIdT, SeqNoT>::ID = 0; + +/// InvalidSequenceNumberForResponse is returned from handleOne when a response +/// call arrives with a sequence number that doesn't correspond to any in-flight +/// function call. +/// +/// This error is fatal because Orc RPC needs to know how to parse the rest of +/// the response call to know where the next call starts, and if it doesn't have +/// a result parser for this sequence number it can't do that. +template <typename SeqNoT> +class InvalidSequenceNumberForResponse +    : public ErrorInfo<InvalidSequenceNumberForResponse<SeqNoT>, RPCFatalError> { +public: +  static char ID; + +  InvalidSequenceNumberForResponse(SeqNoT SeqNo) +      : SeqNo(std::move(SeqNo)) {} + +  std::error_code convertToErrorCode() const override { +    return orcError(OrcErrorCode::UnexpectedRPCCall); +  }; + +  void log(raw_ostream &OS) const override { +    OS << "Response has unknown sequence number " << SeqNo; +  } +private: +  SeqNoT SeqNo; +}; + +template <typename SeqNoT> +char InvalidSequenceNumberForResponse<SeqNoT>::ID = 0; + +/// This non-fatal error will be passed to asynchronous result handlers in place +/// of a result if the connection goes down before a result returns, or if the +/// function to be called cannot be negotiated with the remote. +class ResponseAbandoned : public ErrorInfo<ResponseAbandoned> { +public: +  static char ID; + +  std::error_code convertToErrorCode() const override; +  void log(raw_ostream &OS) const override; +}; + +/// This error is returned if the remote does not have a handler installed for +/// the given RPC function. +class CouldNotNegotiate : public ErrorInfo<CouldNotNegotiate> { +public: +  static char ID; + +  CouldNotNegotiate(std::string Signature); +  std::error_code convertToErrorCode() const override; +  void log(raw_ostream &OS) const override; +  const std::string &getSignature() const { return Signature; } +private: +  std::string Signature; +}; +  template <typename DerivedFunc, typename FnT> class Function;  // RPC Function class. @@ -82,16 +170,6 @@ std::mutex Function<DerivedFunc, RetT(ArgTs...)>::NameMutex;  template <typename DerivedFunc, typename RetT, typename... ArgTs>  std::string Function<DerivedFunc, RetT(ArgTs...)>::Name; -/// Provides a typedef for a tuple containing the decayed argument types. -template <typename T> class FunctionArgsTuple; - -template <typename RetT, typename... ArgTs> -class FunctionArgsTuple<RetT(ArgTs...)> { -public: -  using Type = std::tuple<typename std::decay< -      typename std::remove_reference<ArgTs>::type>::type...>; -}; -  /// Allocates RPC function ids during autonegotiation.  /// Specializations of this class must provide four members:  /// @@ -196,6 +274,16 @@ public:  #endif // _MSC_VER +/// Provides a typedef for a tuple containing the decayed argument types. +template <typename T> class FunctionArgsTuple; + +template <typename RetT, typename... ArgTs> +class FunctionArgsTuple<RetT(ArgTs...)> { +public: +  using Type = std::tuple<typename std::decay< +      typename std::remove_reference<ArgTs>::type>::type...>; +}; +  // ResultTraits provides typedefs and utilities specific to the return type  // of functions.  template <typename RetT> class ResultTraits { @@ -274,43 +362,132 @@ template <> class ResultTraits<Error> : public ResultTraits<void> {};  template <typename RetT>  class ResultTraits<Expected<RetT>> : public ResultTraits<RetT> {}; +// Determines whether an RPC function's defined error return type supports +// error return value. +template <typename T> +class SupportsErrorReturn { +public: +  static const bool value = false; +}; + +template <> +class SupportsErrorReturn<Error> { +public: +  static const bool value = true; +}; + +template <typename T> +class SupportsErrorReturn<Expected<T>> { +public: +  static const bool value = true; +}; + +// RespondHelper packages return values based on whether or not the declared +// RPC function return type supports error returns. +template <bool FuncSupportsErrorReturn> +class RespondHelper; + +// RespondHelper specialization for functions that support error returns. +template <> +class RespondHelper<true> { +public: + +  // Send Expected<T>. +  template <typename WireRetT, typename HandlerRetT, typename ChannelT, +            typename FunctionIdT, typename SequenceNumberT> +  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, +                          SequenceNumberT SeqNo, +                          Expected<HandlerRetT> ResultOrErr) { +    if (!ResultOrErr && ResultOrErr.template errorIsA<RPCFatalError>()) +      return ResultOrErr.takeError(); + +    // Open the response message. +    if (auto Err = C.startSendMessage(ResponseId, SeqNo)) +      return Err; + +    // Serialize the result. +    if (auto Err = +        SerializationTraits<ChannelT, WireRetT, +                            Expected<HandlerRetT>>::serialize( +                                                     C, std::move(ResultOrErr))) +      return Err; + +    // Close the response message. +    return C.endSendMessage(); +  } + +  template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> +  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, +                          SequenceNumberT SeqNo, Error Err) { +    if (Err && Err.isA<RPCFatalError>()) +      return Err; +    if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) +      return Err2; +    if (auto Err2 = serializeSeq(C, std::move(Err))) +      return Err2; +    return C.endSendMessage(); +  } + +}; + +// RespondHelper specialization for functions that do not support error returns. +template <> +class RespondHelper<false> { +public: + +  template <typename WireRetT, typename HandlerRetT, typename ChannelT, +            typename FunctionIdT, typename SequenceNumberT> +  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, +                          SequenceNumberT SeqNo, +                          Expected<HandlerRetT> ResultOrErr) { +    if (auto Err = ResultOrErr.takeError()) +      return Err; + +    // Open the response message. +    if (auto Err = C.startSendMessage(ResponseId, SeqNo)) +      return Err; + +    // Serialize the result. +    if (auto Err = +        SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( +                                                               C, *ResultOrErr)) +      return Err; + +    // Close the response message. +    return C.endSendMessage(); +  } + +  template <typename ChannelT, typename FunctionIdT, typename SequenceNumberT> +  static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, +                          SequenceNumberT SeqNo, Error Err) { +    if (Err) +      return Err; +    if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) +      return Err2; +    return C.endSendMessage(); +  } + +}; + +  // Send a response of the given wire return type (WireRetT) over the  // channel, with the given sequence number.  template <typename WireRetT, typename HandlerRetT, typename ChannelT,            typename FunctionIdT, typename SequenceNumberT> -static Error respond(ChannelT &C, const FunctionIdT &ResponseId, -                     SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) { -  // If this was an error bail out. -  // FIXME: Send an "error" message to the client if this is not a channel -  //        failure? -  if (auto Err = ResultOrErr.takeError()) -    return Err; - -  // Open the response message. -  if (auto Err = C.startSendMessage(ResponseId, SeqNo)) -    return Err; - -  // Serialize the result. -  if (auto Err = -          SerializationTraits<ChannelT, WireRetT, HandlerRetT>::serialize( -              C, *ResultOrErr)) -    return Err; - -  // Close the response message. -  return C.endSendMessage(); +Error respond(ChannelT &C, const FunctionIdT &ResponseId, +              SequenceNumberT SeqNo, Expected<HandlerRetT> ResultOrErr) { +  return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: +    template sendResult<WireRetT>(C, ResponseId, SeqNo, std::move(ResultOrErr));  }  // Send an empty response message on the given channel to indicate that  // the handler ran.  template <typename WireRetT, typename ChannelT, typename FunctionIdT,            typename SequenceNumberT> -static Error respond(ChannelT &C, const FunctionIdT &ResponseId, -                     SequenceNumberT SeqNo, Error Err) { -  if (Err) -    return Err; -  if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) -    return Err2; -  return C.endSendMessage(); +Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, +              Error Err) { +  return RespondHelper<SupportsErrorReturn<WireRetT>::value>:: +    sendResult(C, ResponseId, SeqNo, std::move(Err));  }  // Converts a given type to the equivalent error return type. @@ -339,6 +516,29 @@ public:    using Type = Error;  }; +// Traits class that strips the response function from the list of handler +// arguments. +template <typename FnT> class AsyncHandlerTraits; + +template <typename ResultT, typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Expected<ResultT>)>, ArgTs...)> { +public: +  using Type = Error(ArgTs...); +  using ResultType = Expected<ResultT>; +}; + +template <typename... ArgTs> +class AsyncHandlerTraits<Error(std::function<Error(Error)>, ArgTs...)> { +public: +  using Type = Error(ArgTs...); +  using ResultType = Error; +}; + +template <typename ResponseHandlerT, typename... ArgTs> +class AsyncHandlerTraits<Error(ResponseHandlerT, ArgTs...)> : +    public AsyncHandlerTraits<Error(typename std::decay<ResponseHandlerT>::type, +                                    ArgTs...)> {}; +  // This template class provides utilities related to RPC function handlers.  // The base case applies to non-function types (the template class is  // specialized for function types) and inherits from the appropriate @@ -358,15 +558,20 @@ public:    // Return type of the handler.    using ReturnType = RetT; -  // A std::tuple wrapping the handler arguments. -  using ArgStorage = typename FunctionArgsTuple<RetT(ArgTs...)>::Type; -    // Call the given handler with the given arguments. -  template <typename HandlerT> +  template <typename HandlerT, typename... TArgTs>    static typename WrappedHandlerReturn<RetT>::Type -  unpackAndRun(HandlerT &Handler, ArgStorage &Args) { +  unpackAndRun(HandlerT &Handler, std::tuple<TArgTs...> &Args) {      return unpackAndRunHelper(Handler, Args, -                              llvm::index_sequence_for<ArgTs...>()); +                              llvm::index_sequence_for<TArgTs...>()); +  } + +  // Call the given handler with the given arguments. +  template <typename HandlerT, typename ResponderT, typename... TArgTs> +  static Error unpackAndRunAsync(HandlerT &Handler, ResponderT &Responder, +                                 std::tuple<TArgTs...> &Args) { +    return unpackAndRunAsyncHelper(Handler, Responder, Args, +                                   llvm::index_sequence_for<TArgTs...>());    }    // Call the given handler with the given arguments. @@ -379,11 +584,11 @@ public:      return Error::success();    } -  template <typename HandlerT> +  template <typename HandlerT, typename... TArgTs>    static typename std::enable_if<        !std::is_void<typename HandlerTraits<HandlerT>::ReturnType>::value,        typename HandlerTraits<HandlerT>::ReturnType>::type -  run(HandlerT &Handler, ArgTs... Args) { +  run(HandlerT &Handler, TArgTs... Args) {      return Handler(std::move(Args)...);    } @@ -408,15 +613,31 @@ private:          C, std::get<Indexes>(Args)...);    } -  template <typename HandlerT, size_t... Indexes> +  template <typename HandlerT, typename ArgTuple, size_t... Indexes>    static typename WrappedHandlerReturn<        typename HandlerTraits<HandlerT>::ReturnType>::Type -  unpackAndRunHelper(HandlerT &Handler, ArgStorage &Args, +  unpackAndRunHelper(HandlerT &Handler, ArgTuple &Args,                       llvm::index_sequence<Indexes...>) {      return run(Handler, std::move(std::get<Indexes>(Args))...);    } + + +  template <typename HandlerT, typename ResponderT, typename ArgTuple, +            size_t... Indexes> +  static typename WrappedHandlerReturn< +      typename HandlerTraits<HandlerT>::ReturnType>::Type +  unpackAndRunAsyncHelper(HandlerT &Handler, ResponderT &Responder, +                          ArgTuple &Args, +                          llvm::index_sequence<Indexes...>) { +    return run(Handler, Responder, std::move(std::get<Indexes>(Args))...); +  }  }; +// Handler traits for free functions. +template <typename RetT, typename... ArgTs> +class HandlerTraits<RetT(*)(ArgTs...)> +  : public HandlerTraits<RetT(ArgTs...)> {}; +  // Handler traits for class methods (especially call operators for lambdas).  template <typename Class, typename RetT, typename... ArgTs>  class HandlerTraits<RetT (Class::*)(ArgTs...)> @@ -471,7 +692,7 @@ public:    // Create an error instance representing an abandoned response.    static Error createAbandonedResponseError() { -    return orcError(OrcErrorCode::RPCResponseAbandoned); +    return make_error<ResponseAbandoned>();    }  }; @@ -493,7 +714,7 @@ public:        return Err;      if (auto Err = C.endReceiveMessage())        return Err; -    return Handler(Result); +    return Handler(std::move(Result));    }    // Abandon this response by calling the handler with an 'abandoned response' @@ -538,6 +759,72 @@ private:    HandlerT Handler;  }; +template <typename ChannelT, typename FuncRetT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, Expected<FuncRetT>, HandlerT> +    : public ResponseHandler<ChannelT> { +public: +  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + +  // Handle the result by deserializing it from the channel then passing it +  // to the user defined handler. +  Error handleResponse(ChannelT &C) override { +    using HandlerArgType = typename ResponseHandlerArg< +        typename HandlerTraits<HandlerT>::Type>::ArgType; +    HandlerArgType Result((typename HandlerArgType::value_type())); + +    if (auto Err = +            SerializationTraits<ChannelT, Expected<FuncRetT>, +                                HandlerArgType>::deserialize(C, Result)) +      return Err; +    if (auto Err = C.endReceiveMessage()) +      return Err; +    return Handler(std::move(Result)); +  } + +  // Abandon this response by calling the handler with an 'abandoned response' +  // error. +  void abandon() override { +    if (auto Err = Handler(this->createAbandonedResponseError())) { +      // Handlers should not fail when passed an abandoned response error. +      report_fatal_error(std::move(Err)); +    } +  } + +private: +  HandlerT Handler; +}; + +template <typename ChannelT, typename HandlerT> +class ResponseHandlerImpl<ChannelT, Error, HandlerT> +    : public ResponseHandler<ChannelT> { +public: +  ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + +  // Handle the result by deserializing it from the channel then passing it +  // to the user defined handler. +  Error handleResponse(ChannelT &C) override { +    Error Result = Error::success(); +    if (auto Err = +            SerializationTraits<ChannelT, Error, Error>::deserialize(C, Result)) +      return Err; +    if (auto Err = C.endReceiveMessage()) +      return Err; +    return Handler(std::move(Result)); +  } + +  // Abandon this response by calling the handler with an 'abandoned response' +  // error. +  void abandon() override { +    if (auto Err = Handler(this->createAbandonedResponseError())) { +      // Handlers should not fail when passed an abandoned response error. +      report_fatal_error(std::move(Err)); +    } +  } + +private: +  HandlerT Handler; +}; +  // Create a ResponseHandler from a given user handler.  template <typename ChannelT, typename FuncRetT, typename HandlerT>  std::unique_ptr<ResponseHandler<ChannelT>> createResponseHandler(HandlerT H) { @@ -758,8 +1045,13 @@ public:      auto NegotiateId = FnIdAllocator.getNegotiateId();      RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId;      Handlers[NegotiateId] = wrapHandler<OrcRPCNegotiate>( -        [this](const std::string &Name) { return handleNegotiate(Name); }, -        LaunchPolicy()); +        [this](const std::string &Name) { return handleNegotiate(Name); }); +  } + + +  /// Negotiate a function id for Func with the other end of the channel. +  template <typename Func> Error negotiateFunction(bool Retry = false) { +    return getRemoteFunctionId<Func>(true, Retry).takeError();    }    /// Append a call Func, does not call send on the channel. @@ -777,14 +1069,12 @@ public:      // Look up the function ID.      FunctionIdT FnId; -    if (auto FnIdOrErr = getRemoteFunctionId<Func>()) +    if (auto FnIdOrErr = getRemoteFunctionId<Func>(LazyAutoNegotiation, false))        FnId = *FnIdOrErr;      else { -      // This isn't a channel error so we don't want to abandon other pending -      // responses, but we still need to run the user handler with an error to -      // let them know the call failed. -      if (auto Err = Handler(orcError(OrcErrorCode::UnknownRPCFunction))) -        report_fatal_error(std::move(Err)); +      // Negotiation failed. Notify the handler then return the negotiate-failed +      // error. +      cantFail(Handler(make_error<ResponseAbandoned>()));        return FnIdOrErr.takeError();      } @@ -807,20 +1097,20 @@ public:      // Open the function call message.      if (auto Err = C.startSendMessage(FnId, SeqNo)) {        abandonPendingResponses(); -      return joinErrors(std::move(Err), C.endSendMessage()); +      return Err;      }      // Serialize the call arguments.      if (auto Err = detail::HandlerTraits<typename Func::Type>::serializeArgs(              C, Args...)) {        abandonPendingResponses(); -      return joinErrors(std::move(Err), C.endSendMessage()); +      return Err;      }      // Close the function call messagee.      if (auto Err = C.endSendMessage()) {        abandonPendingResponses(); -      return std::move(Err); +      return Err;      }      return Error::success(); @@ -839,8 +1129,10 @@ public:    Error handleOne() {      FunctionIdT FnId;      SequenceNumberT SeqNo; -    if (auto Err = C.startReceiveMessage(FnId, SeqNo)) +    if (auto Err = C.startReceiveMessage(FnId, SeqNo)) { +      abandonPendingResponses();        return Err; +    }      if (FnId == ResponseId)        return handleResponse(SeqNo);      auto I = Handlers.find(FnId); @@ -848,7 +1140,8 @@ public:        return I->second(C, SeqNo);      // else: No handler found. Report error to client? -    return orcError(OrcErrorCode::UnexpectedRPCCall); +    return make_error<BadFunctionCall<FunctionIdT, SequenceNumberT>>(FnId, +                                                                     SeqNo);    }    /// Helper for handling setter procedures - this method returns a functor that @@ -887,10 +1180,25 @@ public:      SequenceNumberMgr.reset();    } +  /// Remove the handler for the given function. +  /// A handler must currently be registered for this function. +  template <typename Func> +  void removeHandler() { +    auto IdItr = LocalFunctionIds.find(Func::getPrototype()); +    assert(IdItr != LocalFunctionIds.end() && +           "Function does not have a registered handler"); +    auto HandlerItr = Handlers.find(IdItr->second); +    assert(HandlerItr != Handlers.end() && +           "Function does not have a registered handler"); +    Handlers.erase(HandlerItr); +  } + +  /// Clear all handlers. +  void clearHandlers() { +    Handlers.clear(); +  } +  protected: -  // The LaunchPolicy type allows a launch policy to be specified when adding -  // a function handler. See addHandlerImpl. -  using LaunchPolicy = std::function<Error(std::function<Error()>)>;    FunctionIdT getInvalidFunctionId() const {      return FnIdAllocator.getInvalidId(); @@ -899,7 +1207,7 @@ protected:    /// Add the given handler to the handler map and make it available for    /// autonegotiation and execution.    template <typename Func, typename HandlerT> -  void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) { +  void addHandlerImpl(HandlerT Handler) {      static_assert(detail::RPCArgTypeCheck<                        CanDeserializeCheck, typename Func::Type, @@ -908,8 +1216,22 @@ protected:      FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>();      LocalFunctionIds[Func::getPrototype()] = NewFnId; -    Handlers[NewFnId] = -        wrapHandler<Func>(std::move(Handler), std::move(Launch)); +    Handlers[NewFnId] = wrapHandler<Func>(std::move(Handler)); +  } + +  template <typename Func, typename HandlerT> +  void addAsyncHandlerImpl(HandlerT Handler) { + +    static_assert(detail::RPCArgTypeCheck< +                      CanDeserializeCheck, typename Func::Type, +                      typename detail::AsyncHandlerTraits< +                        typename detail::HandlerTraits<HandlerT>::Type +                      >::Type>::value, +                  ""); + +    FunctionIdT NewFnId = FnIdAllocator.template allocate<Func>(); +    LocalFunctionIds[Func::getPrototype()] = NewFnId; +    Handlers[NewFnId] = wrapAsyncHandler<Func>(std::move(Handler));    }    Error handleResponse(SequenceNumberT SeqNo) { @@ -929,7 +1251,8 @@ protected:          // Unlock the pending results map to prevent recursive lock.          Lock.unlock();          abandonPendingResponses(); -        return orcError(OrcErrorCode::UnexpectedRPCResponse); +        return make_error< +                 InvalidSequenceNumberForResponse<SequenceNumberT>>(SeqNo);        }      } @@ -951,41 +1274,39 @@ protected:      return I->second;    } -  // Find the remote FunctionId for the given function, which must be in the -  // RemoteFunctionIds map. -  template <typename Func> Expected<FunctionIdT> getRemoteFunctionId() { -    // Try to find the id for the given function. -    auto I = RemoteFunctionIds.find(Func::getPrototype()); +  // Find the remote FunctionId for the given function. +  template <typename Func> +  Expected<FunctionIdT> getRemoteFunctionId(bool NegotiateIfNotInMap, +                                            bool NegotiateIfInvalid) { +    bool DoNegotiate; -    // If we have it in the map, return it. -    if (I != RemoteFunctionIds.end()) -      return I->second; +    // Check if we already have a function id... +    auto I = RemoteFunctionIds.find(Func::getPrototype()); +    if (I != RemoteFunctionIds.end()) { +      // If it's valid there's nothing left to do. +      if (I->second != getInvalidFunctionId()) +        return I->second; +      DoNegotiate = NegotiateIfInvalid; +    } else +      DoNegotiate = NegotiateIfNotInMap; -    // Otherwise, if we have auto-negotiation enabled, try to negotiate it. -    if (LazyAutoNegotiation) { +    // We don't have a function id for Func yet, but we're allowed to try to +    // negotiate one. +    if (DoNegotiate) {        auto &Impl = static_cast<ImplT &>(*this);        if (auto RemoteIdOrErr = -              Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { -        auto &RemoteId = *RemoteIdOrErr; - -        // If autonegotiation indicates that the remote end doesn't support this -        // function, return an unknown function error. -        if (RemoteId == getInvalidFunctionId()) -          return orcError(OrcErrorCode::UnknownRPCFunction); - -        // Autonegotiation succeeded and returned a valid id. Update the map and -        // return the id. -        RemoteFunctionIds[Func::getPrototype()] = RemoteId; -        return RemoteId; -      } else { -        // Autonegotiation failed. Return the error. +          Impl.template callB<OrcRPCNegotiate>(Func::getPrototype())) { +        RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; +        if (*RemoteIdOrErr == getInvalidFunctionId()) +          return make_error<CouldNotNegotiate>(Func::getPrototype()); +        return *RemoteIdOrErr; +      } else          return RemoteIdOrErr.takeError(); -      }      } -    // No key was available in the map and autonegotiation wasn't enabled. -    // Return an unknown function error. -    return orcError(OrcErrorCode::UnknownRPCFunction); +    // No key was available in the map and we weren't allowed to try to +    // negotiate one, so return an unknown function error. +    return make_error<CouldNotNegotiate>(Func::getPrototype());    }    using WrappedHandlerFn = std::function<Error(ChannelT &, SequenceNumberT)>; @@ -993,12 +1314,15 @@ protected:    // Wrap the given user handler in the necessary argument-deserialization code,    // result-serialization code, and call to the launch policy (if present).    template <typename Func, typename HandlerT> -  WrappedHandlerFn wrapHandler(HandlerT Handler, LaunchPolicy Launch) { -    return [this, Handler, Launch](ChannelT &Channel, -                                   SequenceNumberT SeqNo) mutable -> Error { +  WrappedHandlerFn wrapHandler(HandlerT Handler) { +    return [this, Handler](ChannelT &Channel, +                           SequenceNumberT SeqNo) mutable -> Error {        // Start by deserializing the arguments. -      auto Args = std::make_shared< -          typename detail::HandlerTraits<HandlerT>::ArgStorage>(); +      using ArgsTuple = +          typename detail::FunctionArgsTuple< +            typename detail::HandlerTraits<HandlerT>::Type>::Type; +      auto Args = std::make_shared<ArgsTuple>(); +        if (auto Err =                detail::HandlerTraits<typename Func::Type>::deserializeArgs(                    Channel, *Args)) @@ -1013,22 +1337,49 @@ protected:        if (auto Err = Channel.endReceiveMessage())          return Err; -      // Build the handler/responder. -      auto Responder = [this, Handler, Args, &Channel, -                        SeqNo]() mutable -> Error { -        using HTraits = detail::HandlerTraits<HandlerT>; -        using FuncReturn = typename Func::ReturnType; -        return detail::respond<FuncReturn>( -            Channel, ResponseId, SeqNo, HTraits::unpackAndRun(Handler, *Args)); -      }; - -      // If there is an explicit launch policy then use it to launch the -      // handler. -      if (Launch) -        return Launch(std::move(Responder)); - -      // Otherwise run the handler on the listener thread. -      return Responder(); +      using HTraits = detail::HandlerTraits<HandlerT>; +      using FuncReturn = typename Func::ReturnType; +      return detail::respond<FuncReturn>(Channel, ResponseId, SeqNo, +                                         HTraits::unpackAndRun(Handler, *Args)); +    }; +  } + +  // Wrap the given user handler in the necessary argument-deserialization code, +  // result-serialization code, and call to the launch policy (if present). +  template <typename Func, typename HandlerT> +  WrappedHandlerFn wrapAsyncHandler(HandlerT Handler) { +    return [this, Handler](ChannelT &Channel, +                           SequenceNumberT SeqNo) mutable -> Error { +      // Start by deserializing the arguments. +      using AHTraits = detail::AsyncHandlerTraits< +                         typename detail::HandlerTraits<HandlerT>::Type>; +      using ArgsTuple = +          typename detail::FunctionArgsTuple<typename AHTraits::Type>::Type; +      auto Args = std::make_shared<ArgsTuple>(); + +      if (auto Err = +              detail::HandlerTraits<typename Func::Type>::deserializeArgs( +                  Channel, *Args)) +        return Err; + +      // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning +      // for RPCArgs. Void cast RPCArgs to work around this for now. +      // FIXME: Remove this workaround once we can assume a working GCC version. +      (void)Args; + +      // End receieve message, unlocking the channel for reading. +      if (auto Err = Channel.endReceiveMessage()) +        return Err; + +      using HTraits = detail::HandlerTraits<HandlerT>; +      using FuncReturn = typename Func::ReturnType; +      auto Responder = +        [this, SeqNo](typename AHTraits::ResultType RetVal) -> Error { +          return detail::respond<FuncReturn>(C, ResponseId, SeqNo, +                                             std::move(RetVal)); +        }; + +      return HTraits::unpackAndRunAsync(Handler, Responder, *Args);      };    } @@ -1068,66 +1419,31 @@ public:    MultiThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)        : BaseClass(C, LazyAutoNegotiation) {} -  /// The LaunchPolicy type allows a launch policy to be specified when adding -  /// a function handler. See addHandler. -  using LaunchPolicy = typename BaseClass::LaunchPolicy; -    /// Add a handler for the given RPC function.    /// This installs the given handler functor for the given RPC Function, and    /// makes the RPC function available for negotiation/calling from the remote. -  /// -  /// The optional LaunchPolicy argument can be used to control how the handler -  /// is run when called: -  /// -  /// * If no LaunchPolicy is given, the handler code will be run on the RPC -  ///   handler thread that is reading from the channel. This handler cannot -  ///   make blocking RPC calls (since it would be blocking the thread used to -  ///   get the result), but can make non-blocking calls. -  /// -  /// * If a LaunchPolicy is given, the user's handler will be wrapped in a -  ///   call to serialize and send the result, and the resulting functor (with -  ///   type 'Error()' will be passed to the LaunchPolicy. The user can then -  ///   choose to add the wrapped handler to a work queue, spawn a new thread, -  ///   or anything else.    template <typename Func, typename HandlerT> -  void addHandler(HandlerT Handler, LaunchPolicy Launch = LaunchPolicy()) { -    return this->template addHandlerImpl<Func>(std::move(Handler), -                                               std::move(Launch)); +  void addHandler(HandlerT Handler) { +    return this->template addHandlerImpl<Func>(std::move(Handler));    }    /// Add a class-method as a handler.    template <typename Func, typename ClassT, typename RetT, typename... ArgTs> -  void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...), -                  LaunchPolicy Launch = LaunchPolicy()) { +  void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) {      addHandler<Func>( -      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method), -      Launch); +      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));    } -  /// Negotiate a function id for Func with the other end of the channel. -  template <typename Func> Error negotiateFunction(bool Retry = false) { -    using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; - -    // Check if we already have a function id... -    auto I = this->RemoteFunctionIds.find(Func::getPrototype()); -    if (I != this->RemoteFunctionIds.end()) { -      // If it's valid there's nothing left to do. -      if (I->second != this->getInvalidFunctionId()) -        return Error::success(); -      // If it's invalid and we can't re-attempt negotiation, throw an error. -      if (!Retry) -        return orcError(OrcErrorCode::UnknownRPCFunction); -    } +  template <typename Func, typename HandlerT> +  void addAsyncHandler(HandlerT Handler) { +    return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); +  } -    // We don't have a function id for Func yet, call the remote to try to -    // negotiate one. -    if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) { -      this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; -      if (*RemoteIdOrErr == this->getInvalidFunctionId()) -        return orcError(OrcErrorCode::UnknownRPCFunction); -      return Error::success(); -    } else -      return RemoteIdOrErr.takeError(); +  /// Add a class-method as a handler. +  template <typename Func, typename ClassT, typename RetT, typename... ArgTs> +  void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { +    addAsyncHandler<Func>( +      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));    }    /// Return type for non-blocking call primitives. @@ -1159,7 +1475,6 @@ public:                return Error::success();              },              Args...)) { -      this->abandonPendingResponses();        RTraits::consumeAbandoned(FutureResult.get());        return std::move(Err);      } @@ -1191,15 +1506,9 @@ public:              typename AltRetT = typename Func::ReturnType>    typename detail::ResultTraits<AltRetT>::ErrorReturnType    callB(const ArgTs &... Args) { -    if (auto FutureResOrErr = callNB<Func>(Args...)) { -      if (auto Err = this->C.send()) { -        this->abandonPendingResponses(); -        detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned( -            std::move(FutureResOrErr->get())); -        return std::move(Err); -      } +    if (auto FutureResOrErr = callNB<Func>(Args...))        return FutureResOrErr->get(); -    } else +    else        return FutureResOrErr.takeError();    } @@ -1224,16 +1533,13 @@ private:          SingleThreadedRPCEndpoint<ChannelT, FunctionIdT, SequenceNumberT>,          ChannelT, FunctionIdT, SequenceNumberT>; -  using LaunchPolicy = typename BaseClass::LaunchPolicy; -  public:    SingleThreadedRPCEndpoint(ChannelT &C, bool LazyAutoNegotiation)        : BaseClass(C, LazyAutoNegotiation) {}    template <typename Func, typename HandlerT>    void addHandler(HandlerT Handler) { -    return this->template addHandlerImpl<Func>(std::move(Handler), -                                               LaunchPolicy()); +    return this->template addHandlerImpl<Func>(std::move(Handler));    }    template <typename Func, typename ClassT, typename RetT, typename... ArgTs> @@ -1242,30 +1548,16 @@ public:          detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));    } -  /// Negotiate a function id for Func with the other end of the channel. -  template <typename Func> Error negotiateFunction(bool Retry = false) { -    using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; - -    // Check if we already have a function id... -    auto I = this->RemoteFunctionIds.find(Func::getPrototype()); -    if (I != this->RemoteFunctionIds.end()) { -      // If it's valid there's nothing left to do. -      if (I->second != this->getInvalidFunctionId()) -        return Error::success(); -      // If it's invalid and we can't re-attempt negotiation, throw an error. -      if (!Retry) -        return orcError(OrcErrorCode::UnknownRPCFunction); -    } +  template <typename Func, typename HandlerT> +  void addAsyncHandler(HandlerT Handler) { +    return this->template addAsyncHandlerImpl<Func>(std::move(Handler)); +  } -    // We don't have a function id for Func yet, call the remote to try to -    // negotiate one. -    if (auto RemoteIdOrErr = callB<OrcRPCNegotiate>(Func::getPrototype())) { -      this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; -      if (*RemoteIdOrErr == this->getInvalidFunctionId()) -        return orcError(OrcErrorCode::UnknownRPCFunction); -      return Error::success(); -    } else -      return RemoteIdOrErr.takeError(); +  /// Add a class-method as a handler. +  template <typename Func, typename ClassT, typename RetT, typename... ArgTs> +  void addAsyncHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { +    addAsyncHandler<Func>( +      detail::MemberFnWrapper<ClassT, RetT, ArgTs...>(Object, Method));    }    template <typename Func, typename... ArgTs, @@ -1287,7 +1579,6 @@ public:                return Error::success();              },              Args...)) { -      this->abandonPendingResponses();        detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(            std::move(Result));        return std::move(Err); @@ -1295,7 +1586,6 @@ public:      while (!ReceivedResponse) {        if (auto Err = this->handleOne()) { -        this->abandonPendingResponses();          detail::ResultTraits<typename Func::ReturnType>::consumeAbandoned(              std::move(Result));          return std::move(Err); @@ -1306,24 +1596,40 @@ public:    }  }; +/// Asynchronous dispatch for a function on an RPC endpoint. +template <typename RPCClass, typename Func> +class RPCAsyncDispatch { +public: +  RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {} + +  template <typename HandlerT, typename... ArgTs> +  Error operator()(HandlerT Handler, const ArgTs &... Args) const { +    return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...); +  } + +private: +  RPCClass &Endpoint; +}; + +/// Construct an asynchronous dispatcher from an RPC endpoint and a Func. +template <typename Func, typename RPCEndpointT> +RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) { +  return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint); +} +  /// \brief Allows a set of asynchrounous calls to be dispatched, and then  ///        waited on as a group. -template <typename RPCClass> class ParallelCallGroup { +class ParallelCallGroup {  public: -  /// \brief Construct a parallel call group for the given RPC. -  ParallelCallGroup(RPCClass &RPC) : RPC(RPC), NumOutstandingCalls(0) {} - +  ParallelCallGroup() = default;    ParallelCallGroup(const ParallelCallGroup &) = delete;    ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;    /// \brief Make as asynchronous call. -  /// -  /// Does not issue a send call to the RPC's channel. The channel may use this -  /// to batch up subsequent calls. A send will automatically be sent when wait -  /// is called. -  template <typename Func, typename HandlerT, typename... ArgTs> -  Error appendCall(HandlerT Handler, const ArgTs &... Args) { +  template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs> +  Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler, +             const ArgTs &... Args) {      // Increment the count of outstanding calls. This has to happen before      // we invoke the call, as the handler may (depending on scheduling)      // be run immediately on another thread, and we don't want the decrement @@ -1346,38 +1652,21 @@ public:        return Err;      }; -    return RPC.template appendCallAsync<Func>(std::move(WrappedHandler), -                                              Args...); -  } - -  /// \brief Make an asynchronous call. -  /// -  /// The same as appendCall, but also calls send on the channel immediately. -  /// Prefer appendCall if you are about to issue a "wait" call shortly, as -  /// this may allow the channel to better batch the calls. -  template <typename Func, typename HandlerT, typename... ArgTs> -  Error call(HandlerT Handler, const ArgTs &... Args) { -    if (auto Err = appendCall(std::move(Handler), Args...)) -      return Err; -    return RPC.sendAppendedCalls(); +    return AsyncDispatch(std::move(WrappedHandler), Args...);    }    /// \brief Blocks until all calls have been completed and their return value    ///        handlers run. -  Error wait() { -    if (auto Err = RPC.sendAppendedCalls()) -      return Err; +  void wait() {      std::unique_lock<std::mutex> Lock(M);      while (NumOutstandingCalls > 0)        CV.wait(Lock); -    return Error::success();    }  private: -  RPCClass &RPC;    std::mutex M;    std::condition_variable CV; -  uint32_t NumOutstandingCalls; +  uint32_t NumOutstandingCalls = 0;  };  /// @brief Convenience class for grouping RPC Functions into APIs that can be diff --git a/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h b/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h index 0588d2228598..babcc7f26aab 100644 --- a/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h +++ b/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h @@ -1,4 +1,4 @@ -//===- ObjectLinkingLayer.h - Add object files to a JIT process -*- C++ -*-===// +//===-- RTDyldObjectLinkingLayer.h - RTDyld-based jit linking  --*- C++ -*-===//  //  //                     The LLVM Compiler Infrastructure  // @@ -7,12 +7,12 @@  //  //===----------------------------------------------------------------------===//  // -// Contains the definition for the object layer of the JIT. +// Contains the definition for an RTDyld-based, in-process object linking layer.  //  //===----------------------------------------------------------------------===// -#ifndef LLVM_EXECUTIONENGINE_ORC_OBJECTLINKINGLAYER_H -#define LLVM_EXECUTIONENGINE_ORC_OBJECTLINKINGLAYER_H +#ifndef LLVM_EXECUTIONENGINE_ORC_RTDYLDOBJECTLINKINGLAYER_H +#define LLVM_EXECUTIONENGINE_ORC_RTDYLDOBJECTLINKINGLAYER_H  #include "llvm/ADT/STLExtras.h"  #include "llvm/ADT/StringMap.h" @@ -35,7 +35,7 @@  namespace llvm {  namespace orc { -class ObjectLinkingLayerBase { +class RTDyldObjectLinkingLayerBase {  protected:    /// @brief Holds a set of objects to be allocated/linked as a unit in the JIT.    /// @@ -87,7 +87,7 @@ public:  class DoNothingOnNotifyLoaded {  public:    template <typename ObjSetT, typename LoadResult> -  void operator()(ObjectLinkingLayerBase::ObjSetHandleT, const ObjSetT &, +  void operator()(RTDyldObjectLinkingLayerBase::ObjSetHandleT, const ObjSetT &,                    const LoadResult &) {}  }; @@ -98,7 +98,7 @@ public:  /// symbols queried. All objects added to this layer can see each other's  /// symbols.  template <typename NotifyLoadedFtor = DoNothingOnNotifyLoaded> -class ObjectLinkingLayer : public ObjectLinkingLayerBase { +class RTDyldObjectLinkingLayer : public RTDyldObjectLinkingLayerBase {  public:    /// @brief Functor for receiving finalization notifications.    typedef std::function<void(ObjSetHandleT)> NotifyFinalizedFtor; @@ -227,7 +227,7 @@ public:    /// @brief Construct an ObjectLinkingLayer with the given NotifyLoaded,    ///        and NotifyFinalized functors. -  ObjectLinkingLayer( +  RTDyldObjectLinkingLayer(        NotifyLoadedFtor NotifyLoaded = NotifyLoadedFtor(),        NotifyFinalizedFtor NotifyFinalized = NotifyFinalizedFtor())        : NotifyLoaded(std::move(NotifyLoaded)), @@ -359,4 +359,4 @@ private:  } // end namespace orc  } // end namespace llvm -#endif // LLVM_EXECUTIONENGINE_ORC_OBJECTLINKINGLAYER_H +#endif // LLVM_EXECUTIONENGINE_ORC_RTDYLDOBJECTLINKINGLAYER_H diff --git a/include/llvm/ExecutionEngine/Orc/RawByteChannel.h b/include/llvm/ExecutionEngine/Orc/RawByteChannel.h index 3b6c84eb1965..52a546f7c6eb 100644 --- a/include/llvm/ExecutionEngine/Orc/RawByteChannel.h +++ b/include/llvm/ExecutionEngine/Orc/RawByteChannel.h @@ -48,7 +48,11 @@ public:    template <typename FunctionIdT, typename SequenceIdT>    Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {      writeLock.lock(); -    return serializeSeq(*this, FnId, SeqNo); +    if (auto Err = serializeSeq(*this, FnId, SeqNo)) { +      writeLock.unlock(); +      return Err; +    } +    return Error::success();    }    /// Notify the channel that we're ending a message send. @@ -63,7 +67,11 @@ public:    template <typename FunctionIdT, typename SequenceNumberT>    Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {      readLock.lock(); -    return deserializeSeq(*this, FnId, SeqNo); +    if (auto Err = deserializeSeq(*this, FnId, SeqNo)) { +      readLock.unlock(); +      return Err; +    } +    return Error::success();    }    /// Notify the channel that we're ending a message receive. @@ -113,11 +121,19 @@ class SerializationTraits<ChannelT, bool, bool,                                RawByteChannel, ChannelT>::value>::type> {  public:    static Error serialize(ChannelT &C, bool V) { -    return C.appendBytes(reinterpret_cast<const char *>(&V), 1); +    uint8_t Tmp = V ? 1 : 0; +    if (auto Err = +          C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1)) +      return Err; +    return Error::success();    }    static Error deserialize(ChannelT &C, bool &V) { -    return C.readBytes(reinterpret_cast<char *>(&V), 1); +    uint8_t Tmp = 0; +    if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1)) +      return Err; +    V = Tmp != 0; +    return Error::success();    }  }; @@ -134,10 +150,12 @@ public:    }  }; -template <typename ChannelT> -class SerializationTraits<ChannelT, std::string, const char *, -                          typename std::enable_if<std::is_base_of< -                              RawByteChannel, ChannelT>::value>::type> { +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, std::string, T, +                          typename std::enable_if< +                            std::is_base_of<RawByteChannel, ChannelT>::value && +                            (std::is_same<T, const char*>::value || +                             std::is_same<T, char*>::value)>::type> {  public:    static Error serialize(RawByteChannel &C, const char *S) {      return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, | 
