diff options
Diffstat (limited to 'llvm/include/llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h')
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h | 184 |
1 files changed, 184 insertions, 0 deletions
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h b/llvm/include/llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h new file mode 100644 index 000000000000..50e26f8449df --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h @@ -0,0 +1,184 @@ +//===- llvm/ExecutionEngine/Orc/RPC/RawByteChannel.h ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H +#define LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/Error.h" +#include <cstdint> +#include <mutex> +#include <string> +#include <type_traits> + +namespace llvm { +namespace orc { +namespace rpc { + +/// Interface for byte-streams to be used with RPC. +class RawByteChannel { +public: + virtual ~RawByteChannel() = default; + + /// Read Size bytes from the stream into *Dst. + virtual Error readBytes(char *Dst, unsigned Size) = 0; + + /// Read size bytes from *Src and append them to the stream. + virtual Error appendBytes(const char *Src, unsigned Size) = 0; + + /// Flush the stream if possible. + virtual Error send() = 0; + + /// Notify the channel that we're starting a message send. + /// Locks the channel for writing. + template <typename FunctionIdT, typename SequenceIdT> + Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { + writeLock.lock(); + if (auto Err = serializeSeq(*this, FnId, SeqNo)) { + writeLock.unlock(); + return Err; + } + return Error::success(); + } + + /// Notify the channel that we're ending a message send. + /// Unlocks the channel for writing. + Error endSendMessage() { + writeLock.unlock(); + return Error::success(); + } + + /// Notify the channel that we're starting a message receive. + /// Locks the channel for reading. + template <typename FunctionIdT, typename SequenceNumberT> + Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { + readLock.lock(); + if (auto Err = deserializeSeq(*this, FnId, SeqNo)) { + readLock.unlock(); + return Err; + } + return Error::success(); + } + + /// Notify the channel that we're ending a message receive. + /// Unlocks the channel for reading. + Error endReceiveMessage() { + readLock.unlock(); + return Error::success(); + } + + /// Get the lock for stream reading. + std::mutex &getReadLock() { return readLock; } + + /// Get the lock for stream writing. + std::mutex &getWriteLock() { return writeLock; } + +private: + std::mutex readLock, writeLock; +}; + +template <typename ChannelT, typename T> +class SerializationTraits< + ChannelT, T, T, + typename std::enable_if< + std::is_base_of<RawByteChannel, ChannelT>::value && + (std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value || + std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value || + std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value || + std::is_same<T, uint64_t>::value || std::is_same<T, int64_t>::value || + std::is_same<T, char>::value)>::type> { +public: + static Error serialize(ChannelT &C, T V) { + support::endian::byte_swap<T, support::big>(V); + return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); + }; + + static Error deserialize(ChannelT &C, T &V) { + if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) + return Err; + support::endian::byte_swap<T, support::big>(V); + return Error::success(); + }; +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, bool, bool, + typename std::enable_if<std::is_base_of< + RawByteChannel, ChannelT>::value>::type> { +public: + static Error serialize(ChannelT &C, bool V) { + uint8_t Tmp = V ? 1 : 0; + if (auto Err = + C.appendBytes(reinterpret_cast<const char *>(&Tmp), 1)) + return Err; + return Error::success(); + } + + static Error deserialize(ChannelT &C, bool &V) { + uint8_t Tmp = 0; + if (auto Err = C.readBytes(reinterpret_cast<char *>(&Tmp), 1)) + return Err; + V = Tmp != 0; + return Error::success(); + } +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, std::string, StringRef, + typename std::enable_if<std::is_base_of< + RawByteChannel, ChannelT>::value>::type> { +public: + /// RPC channel serialization for std::strings. + static Error serialize(RawByteChannel &C, StringRef S) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) + return Err; + return C.appendBytes((const char *)S.data(), S.size()); + } +}; + +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, std::string, T, + typename std::enable_if< + std::is_base_of<RawByteChannel, ChannelT>::value && + (std::is_same<T, const char*>::value || + std::is_same<T, char*>::value)>::type> { +public: + static Error serialize(RawByteChannel &C, const char *S) { + return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, + S); + } +}; + +template <typename ChannelT> +class SerializationTraits<ChannelT, std::string, std::string, + typename std::enable_if<std::is_base_of< + RawByteChannel, ChannelT>::value>::type> { +public: + /// RPC channel serialization for std::strings. + static Error serialize(RawByteChannel &C, const std::string &S) { + return SerializationTraits<ChannelT, std::string, StringRef>::serialize(C, + S); + } + + /// RPC channel deserialization for std::strings. + static Error deserialize(RawByteChannel &C, std::string &S) { + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + S.resize(Count); + return C.readBytes(&S[0], Count); + } +}; + +} // end namespace rpc +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H |