diff options
Diffstat (limited to 'source/Host/common')
-rw-r--r-- | source/Host/common/MainLoop.cpp | 382 | ||||
-rw-r--r-- | source/Host/common/Socket.cpp | 38 | ||||
-rw-r--r-- | source/Host/common/SocketAddress.cpp | 9 | ||||
-rw-r--r-- | source/Host/common/TCPSocket.cpp | 256 | ||||
-rw-r--r-- | source/Host/common/UDPSocket.cpp | 84 |
5 files changed, 610 insertions, 159 deletions
diff --git a/source/Host/common/MainLoop.cpp b/source/Host/common/MainLoop.cpp new file mode 100644 index 0000000000000..8a9d4f020d5f0 --- /dev/null +++ b/source/Host/common/MainLoop.cpp @@ -0,0 +1,382 @@ +//===-- MainLoop.cpp --------------------------------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Config/llvm-config.h" + +#include "lldb/Host/MainLoop.h" +#include "lldb/Utility/Error.h" +#include <algorithm> +#include <cassert> +#include <cerrno> +#include <csignal> +#include <vector> +#include <time.h> + +#if HAVE_SYS_EVENT_H +#include <sys/event.h> +#elif defined(LLVM_ON_WIN32) +#include <winsock2.h> +#else +#include <poll.h> +#endif + +#ifdef LLVM_ON_WIN32 +#define POLL WSAPoll +#else +#define POLL poll +#endif + +#ifdef __ANDROID__ +#define FORCE_PSELECT +#endif + +#if SIGNAL_POLLING_UNSUPPORTED +#ifdef LLVM_ON_WIN32 +typedef int sigset_t; +typedef int siginfo_t; +#endif + +int ppoll(struct pollfd *fds, size_t nfds, const struct timespec *timeout_ts, + const sigset_t *) { + int timeout = + (timeout_ts == nullptr) + ? -1 + : (timeout_ts->tv_sec * 1000 + timeout_ts->tv_nsec / 1000000); + return POLL(fds, nfds, timeout); +} + +#endif + +using namespace lldb; +using namespace lldb_private; + +static sig_atomic_t g_signal_flags[NSIG]; + +static void SignalHandler(int signo, siginfo_t *info, void *) { + assert(signo < NSIG); + g_signal_flags[signo] = 1; +} + +class MainLoop::RunImpl { +public: + // TODO: Use llvm::Expected<T> + static std::unique_ptr<RunImpl> Create(MainLoop &loop, Error &error); + ~RunImpl(); + + Error Poll(); + + template <typename F> void ForEachReadFD(F &&f); + template <typename F> void ForEachSignal(F &&f); + +private: + MainLoop &loop; + +#if HAVE_SYS_EVENT_H + int queue_id; + std::vector<struct kevent> in_events; + struct kevent out_events[4]; + int num_events = -1; + + RunImpl(MainLoop &loop, int queue_id) : loop(loop), queue_id(queue_id) { + in_events.reserve(loop.m_read_fds.size() + loop.m_signals.size()); + } +#else + std::vector<int> signals; +#ifdef FORCE_PSELECT + fd_set read_fd_set; +#else + std::vector<struct pollfd> read_fds; +#endif + + RunImpl(MainLoop &loop) : loop(loop) { + signals.reserve(loop.m_signals.size()); + } + + sigset_t get_sigmask(); +#endif +}; + +#if HAVE_SYS_EVENT_H +MainLoop::RunImpl::~RunImpl() { + int r = close(queue_id); + assert(r == 0); + (void)r; +} +std::unique_ptr<MainLoop::RunImpl> MainLoop::RunImpl::Create(MainLoop &loop, Error &error) +{ + error.Clear(); + int queue_id = kqueue(); + if(queue_id < 0) { + error = Error(errno, eErrorTypePOSIX); + return nullptr; + } + return std::unique_ptr<RunImpl>(new RunImpl(loop, queue_id)); +} + +Error MainLoop::RunImpl::Poll() { + in_events.resize(loop.m_read_fds.size() + loop.m_signals.size()); + unsigned i = 0; + for (auto &fd : loop.m_read_fds) + EV_SET(&in_events[i++], fd.first, EVFILT_READ, EV_ADD, 0, 0, 0); + + for (const auto &sig : loop.m_signals) + EV_SET(&in_events[i++], sig.first, EVFILT_SIGNAL, EV_ADD, 0, 0, 0); + + num_events = kevent(queue_id, in_events.data(), in_events.size(), out_events, + llvm::array_lengthof(out_events), nullptr); + + if (num_events < 0) + return Error("kevent() failed with error %d\n", num_events); + return Error(); +} + +template <typename F> void MainLoop::RunImpl::ForEachReadFD(F &&f) { + assert(num_events >= 0); + for (int i = 0; i < num_events; ++i) { + f(out_events[i].ident); + if (loop.m_terminate_request) + return; + } +} +template <typename F> void MainLoop::RunImpl::ForEachSignal(F && f) {} +#else +MainLoop::RunImpl::~RunImpl() {} +std::unique_ptr<MainLoop::RunImpl> MainLoop::RunImpl::Create(MainLoop &loop, Error &error) +{ + error.Clear(); + return std::unique_ptr<RunImpl>(new RunImpl(loop)); +} + +sigset_t MainLoop::RunImpl::get_sigmask() { +#if SIGNAL_POLLING_UNSUPPORTED + return 0; +#else + sigset_t sigmask; + int ret = pthread_sigmask(SIG_SETMASK, nullptr, &sigmask); + assert(ret == 0); + (void) ret; + + for (const auto &sig : loop.m_signals) { + signals.push_back(sig.first); + sigdelset(&sigmask, sig.first); + } + return sigmask; +#endif +} + +#ifdef FORCE_PSELECT +Error MainLoop::RunImpl::Poll() { + signals.clear(); + + FD_ZERO(&read_fd_set); + int nfds = 0; + for (const auto &fd : loop.m_read_fds) { + FD_SET(fd.first, &read_fd_set); + nfds = std::max(nfds, fd.first + 1); + } + + sigset_t sigmask = get_sigmask(); + if (pselect(nfds, &read_fd_set, nullptr, nullptr, nullptr, &sigmask) == -1 && + errno != EINTR) + return Error(errno, eErrorTypePOSIX); + + return Error(); +} + +template <typename F> void MainLoop::RunImpl::ForEachReadFD(F &&f) { + for (const auto &fd : loop.m_read_fds) { + if(!FD_ISSET(fd.first, &read_fd_set)) + continue; + + f(fd.first); + if (loop.m_terminate_request) + return; + } +} +#else +Error MainLoop::RunImpl::Poll() { + signals.clear(); + read_fds.clear(); + + sigset_t sigmask = get_sigmask(); + + for (const auto &fd : loop.m_read_fds) { + struct pollfd pfd; + pfd.fd = fd.first; + pfd.events = POLLIN; + pfd.revents = 0; + read_fds.push_back(pfd); + } + + if (ppoll(read_fds.data(), read_fds.size(), nullptr, &sigmask) == -1 && + errno != EINTR) + return Error(errno, eErrorTypePOSIX); + + return Error(); +} + +template <typename F> void MainLoop::RunImpl::ForEachReadFD(F &&f) { + for (const auto &fd : read_fds) { + if ((fd.revents & POLLIN) == 0) + continue; + + f(fd.fd); + if (loop.m_terminate_request) + return; + } +} +#endif + +template <typename F> void MainLoop::RunImpl::ForEachSignal(F &&f) { + for (int sig : signals) { + if (g_signal_flags[sig] == 0) + continue; // No signal + g_signal_flags[sig] = 0; + f(sig); + + if (loop.m_terminate_request) + return; + } +} +#endif + +MainLoop::~MainLoop() { + assert(m_read_fds.size() == 0); + assert(m_signals.size() == 0); +} + +MainLoop::ReadHandleUP +MainLoop::RegisterReadObject(const IOObjectSP &object_sp, + const Callback &callback, Error &error) { +#ifdef LLVM_ON_WIN32 + if (object_sp->GetFdType() != IOObject:: eFDTypeSocket) { + error.SetErrorString("MainLoop: non-socket types unsupported on Windows"); + return nullptr; + } +#endif + if (!object_sp || !object_sp->IsValid()) { + error.SetErrorString("IO object is not valid."); + return nullptr; + } + + const bool inserted = + m_read_fds.insert({object_sp->GetWaitableHandle(), callback}).second; + if (!inserted) { + error.SetErrorStringWithFormat("File descriptor %d already monitored.", + object_sp->GetWaitableHandle()); + return nullptr; + } + + return CreateReadHandle(object_sp); +} + +// We shall block the signal, then install the signal handler. The signal will +// be unblocked in +// the Run() function to check for signal delivery. +MainLoop::SignalHandleUP +MainLoop::RegisterSignal(int signo, const Callback &callback, + Error &error) { +#ifdef SIGNAL_POLLING_UNSUPPORTED + error.SetErrorString("Signal polling is not supported on this platform."); + return nullptr; +#else + if (m_signals.find(signo) != m_signals.end()) { + error.SetErrorStringWithFormat("Signal %d already monitored.", signo); + return nullptr; + } + + SignalInfo info; + info.callback = callback; + struct sigaction new_action; + new_action.sa_sigaction = &SignalHandler; + new_action.sa_flags = SA_SIGINFO; + sigemptyset(&new_action.sa_mask); + sigaddset(&new_action.sa_mask, signo); + + sigset_t old_set; + if (int ret = pthread_sigmask(SIG_BLOCK, &new_action.sa_mask, &old_set)) { + error.SetErrorStringWithFormat("pthread_sigmask failed with error %d\n", + ret); + return nullptr; + } + + info.was_blocked = sigismember(&old_set, signo); + if (sigaction(signo, &new_action, &info.old_action) == -1) { + error.SetErrorToErrno(); + if (!info.was_blocked) + pthread_sigmask(SIG_UNBLOCK, &new_action.sa_mask, nullptr); + return nullptr; + } + + m_signals.insert({signo, info}); + g_signal_flags[signo] = 0; + + return SignalHandleUP(new SignalHandle(*this, signo)); +#endif +} + +void MainLoop::UnregisterReadObject(IOObject::WaitableHandle handle) { + bool erased = m_read_fds.erase(handle); + UNUSED_IF_ASSERT_DISABLED(erased); + assert(erased); +} + +void MainLoop::UnregisterSignal(int signo) { +#if SIGNAL_POLLING_UNSUPPORTED + Error("Signal polling is not supported on this platform."); +#else + // We undo the actions of RegisterSignal on a best-effort basis. + auto it = m_signals.find(signo); + assert(it != m_signals.end()); + + sigaction(signo, &it->second.old_action, nullptr); + + sigset_t set; + sigemptyset(&set); + sigaddset(&set, signo); + pthread_sigmask(it->second.was_blocked ? SIG_BLOCK : SIG_UNBLOCK, &set, + nullptr); + + m_signals.erase(it); +#endif +} + +Error MainLoop::Run() { + m_terminate_request = false; + + Error error; + auto impl = RunImpl::Create(*this, error); + if (!impl) + return error; + + // run until termination or until we run out of things to listen to + while (!m_terminate_request && (!m_read_fds.empty() || !m_signals.empty())) { + + error = impl->Poll(); + if (error.Fail()) + return error; + + impl->ForEachSignal([&](int sig) { + auto it = m_signals.find(sig); + if (it != m_signals.end()) + it->second.callback(*this); // Do the work + }); + if (m_terminate_request) + return Error(); + + impl->ForEachReadFD([&](int fd) { + auto it = m_read_fds.find(fd); + if (it != m_read_fds.end()) + it->second(*this); // Do the work + }); + if (m_terminate_request) + return Error(); + } + return Error(); +} diff --git a/source/Host/common/Socket.cpp b/source/Host/common/Socket.cpp index 2a665ddacb64f..d73b5d0ad0732 100644 --- a/source/Host/common/Socket.cpp +++ b/source/Host/common/Socket.cpp @@ -18,6 +18,8 @@ #include "lldb/Utility/Log.h" #include "lldb/Utility/RegularExpression.h" +#include "llvm/ADT/STLExtras.h" + #ifndef LLDB_DISABLE_POSIX #include "lldb/Host/posix/DomainSocket.h" @@ -67,9 +69,11 @@ bool IsInterrupted() { } } -Socket::Socket(NativeSocket socket, SocketProtocol protocol, bool should_close) +Socket::Socket(SocketProtocol protocol, bool should_close, + bool child_processes_inherit) : IOObject(eFDTypeSocket, should_close), m_protocol(protocol), - m_socket(socket) {} + m_socket(kInvalidSocketValue), + m_child_processes_inherit(child_processes_inherit) {} Socket::~Socket() { Close(); } @@ -81,14 +85,17 @@ std::unique_ptr<Socket> Socket::Create(const SocketProtocol protocol, std::unique_ptr<Socket> socket_up; switch (protocol) { case ProtocolTcp: - socket_up.reset(new TCPSocket(child_processes_inherit, error)); + socket_up = + llvm::make_unique<TCPSocket>(true, child_processes_inherit); break; case ProtocolUdp: - socket_up.reset(new UDPSocket(child_processes_inherit, error)); + socket_up = + llvm::make_unique<UDPSocket>(true, child_processes_inherit); break; case ProtocolUnixDomain: #ifndef LLDB_DISABLE_POSIX - socket_up.reset(new DomainSocket(child_processes_inherit, error)); + socket_up = + llvm::make_unique<DomainSocket>(true, child_processes_inherit); #else error.SetErrorString( "Unix domain sockets are not supported on this platform."); @@ -96,7 +103,8 @@ std::unique_ptr<Socket> Socket::Create(const SocketProtocol protocol, break; case ProtocolUnixAbstract: #ifdef __linux__ - socket_up.reset(new AbstractSocket(child_processes_inherit, error)); + socket_up = + llvm::make_unique<AbstractSocket>(child_processes_inherit); #else error.SetErrorString( "Abstract domain sockets are not supported on this platform."); @@ -145,7 +153,7 @@ Error Socket::TcpListen(llvm::StringRef host_and_port, return error; std::unique_ptr<TCPSocket> listen_socket( - new TCPSocket(child_processes_inherit, error)); + new TCPSocket(true, child_processes_inherit)); if (error.Fail()) return error; @@ -208,7 +216,7 @@ Error Socket::UnixDomainAccept(llvm::StringRef name, if (error.Fail()) return error; - error = listen_socket->Accept(name, child_processes_inherit, socket); + error = listen_socket->Accept(socket); return error; } @@ -240,18 +248,22 @@ Error Socket::UnixAbstractAccept(llvm::StringRef name, if (error.Fail()) return error; - error = listen_socket->Accept(name, child_processes_inherit, socket); + error = listen_socket->Accept(socket); return error; } bool Socket::DecodeHostAndPort(llvm::StringRef host_and_port, std::string &host_str, std::string &port_str, int32_t &port, Error *error_ptr) { - static RegularExpression g_regex(llvm::StringRef("([^:]+):([0-9]+)")); + static RegularExpression g_regex( + llvm::StringRef("([^:]+|\\[[0-9a-fA-F:]+.*\\]):([0-9]+)")); RegularExpression::Match regex_match(2); if (g_regex.Execute(host_and_port, ®ex_match)) { if (regex_match.GetMatchAtIndex(host_and_port.data(), 1, host_str) && regex_match.GetMatchAtIndex(host_and_port.data(), 2, port_str)) { + // IPv6 addresses are wrapped in [] when specified with ports + if (host_str.front() == '[' && host_str.back() == ']') + host_str = host_str.substr(1, host_str.size() - 2); bool ok = false; port = StringConvert::ToUInt32(port_str.c_str(), UINT32_MAX, 10, &ok); if (ok && port <= UINT16_MAX) { @@ -404,12 +416,12 @@ NativeSocket Socket::CreateSocket(const int domain, const int type, const int protocol, bool child_processes_inherit, Error &error) { error.Clear(); - auto socketType = type; + auto socket_type = type; #ifdef SOCK_CLOEXEC if (!child_processes_inherit) - socketType |= SOCK_CLOEXEC; + socket_type |= SOCK_CLOEXEC; #endif - auto sock = ::socket(domain, socketType, protocol); + auto sock = ::socket(domain, socket_type, protocol); if (sock == kInvalidSocketValue) SetLastError(error); diff --git a/source/Host/common/SocketAddress.cpp b/source/Host/common/SocketAddress.cpp index b41cef6ca2ebd..440ae5d9027fd 100644 --- a/source/Host/common/SocketAddress.cpp +++ b/source/Host/common/SocketAddress.cpp @@ -6,6 +6,12 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// +// Note: This file is used on Darwin by debugserver, so it needs to remain as +// self contained as possible, and devoid of references to LLVM unless +// there is compelling reason. +// +//===----------------------------------------------------------------------===// #if defined(_MSC_VER) #define _WINSOCK_DEPRECATED_NO_WARNINGS @@ -227,7 +233,8 @@ bool SocketAddress::getaddrinfo(const char *host, const char *service, int ai_flags) { Clear(); - auto addresses = GetAddressInfo(host, service, ai_family, ai_socktype, ai_protocol, ai_flags); + auto addresses = GetAddressInfo(host, service, ai_family, ai_socktype, + ai_protocol, ai_flags); if (!addresses.empty()) *this = addresses[0]; return IsValid(); diff --git a/source/Host/common/TCPSocket.cpp b/source/Host/common/TCPSocket.cpp index 9a009280a9045..55db4bb0c4562 100644 --- a/source/Host/common/TCPSocket.cpp +++ b/source/Host/common/TCPSocket.cpp @@ -14,30 +14,57 @@ #include "lldb/Host/common/TCPSocket.h" #include "lldb/Host/Config.h" +#include "lldb/Host/MainLoop.h" #include "lldb/Utility/Log.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/Support/raw_ostream.h" + #ifndef LLDB_DISABLE_POSIX #include <arpa/inet.h> #include <netinet/tcp.h> #include <sys/socket.h> #endif +#if defined(LLVM_ON_WIN32) +#include <winsock2.h> +#endif + +#ifdef LLVM_ON_WIN32 +#define CLOSE_SOCKET closesocket +typedef const char *set_socket_option_arg_type; +#else +#define CLOSE_SOCKET ::close +typedef const void *set_socket_option_arg_type; +#endif + using namespace lldb; using namespace lldb_private; namespace { - -const int kDomain = AF_INET; const int kType = SOCK_STREAM; } -TCPSocket::TCPSocket(NativeSocket socket, bool should_close) - : Socket(socket, ProtocolTcp, should_close) {} +TCPSocket::TCPSocket(bool should_close, bool child_processes_inherit) + : Socket(ProtocolTcp, should_close, child_processes_inherit) {} -TCPSocket::TCPSocket(bool child_processes_inherit, Error &error) - : TCPSocket(CreateSocket(kDomain, kType, IPPROTO_TCP, - child_processes_inherit, error), - true) {} +TCPSocket::TCPSocket(NativeSocket socket, const TCPSocket &listen_socket) + : Socket(ProtocolTcp, listen_socket.m_should_close_fd, + listen_socket.m_child_processes_inherit) { + m_socket = socket; +} + +TCPSocket::TCPSocket(NativeSocket socket, bool should_close, + bool child_processes_inherit) + : Socket(ProtocolTcp, should_close, child_processes_inherit) { + m_socket = socket; +} + +TCPSocket::~TCPSocket() { CloseListenSockets(); } + +bool TCPSocket::IsValid() const { + return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0; +} // Return the port number that is being used by the socket. uint16_t TCPSocket::GetLocalPortNumber() const { @@ -46,6 +73,12 @@ uint16_t TCPSocket::GetLocalPortNumber() const { socklen_t sock_addr_len = sock_addr.GetMaxLength(); if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0) return sock_addr.GetPort(); + } else if (!m_listen_sockets.empty()) { + SocketAddress sock_addr; + socklen_t sock_addr_len = sock_addr.GetMaxLength(); + if (::getsockname(m_listen_sockets.begin()->first, sock_addr, + &sock_addr_len) == 0) + return sock_addr.GetPort(); } return 0; } @@ -84,9 +117,18 @@ std::string TCPSocket::GetRemoteIPAddress() const { return ""; } +Error TCPSocket::CreateSocket(int domain) { + Error error; + if (IsValid()) + error = Close(); + if (error.Fail()) + return error; + m_socket = Socket::CreateSocket(domain, kType, IPPROTO_TCP, + m_child_processes_inherit, error); + return error; +} + Error TCPSocket::Connect(llvm::StringRef name) { - if (m_socket == kInvalidSocketValue) - return Error("Invalid socket"); Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_COMMUNICATION)); if (log) @@ -99,146 +141,140 @@ Error TCPSocket::Connect(llvm::StringRef name) { if (!DecodeHostAndPort(name, host_str, port_str, port, &error)) return error; - struct sockaddr_in sa; - ::memset(&sa, 0, sizeof(sa)); - sa.sin_family = kDomain; - sa.sin_port = htons(port); - - int inet_pton_result = ::inet_pton(kDomain, host_str.c_str(), &sa.sin_addr); - - if (inet_pton_result <= 0) { - struct hostent *host_entry = gethostbyname(host_str.c_str()); - if (host_entry) - host_str = ::inet_ntoa(*(struct in_addr *)*host_entry->h_addr_list); - inet_pton_result = ::inet_pton(kDomain, host_str.c_str(), &sa.sin_addr); - if (inet_pton_result <= 0) { - if (inet_pton_result == -1) - SetLastError(error); - else - error.SetErrorStringWithFormat("invalid host string: '%s'", - host_str.c_str()); + auto addresses = lldb_private::SocketAddress::GetAddressInfo( + host_str.c_str(), NULL, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP); + for (auto address : addresses) { + error = CreateSocket(address.GetFamily()); + if (error.Fail()) + continue; - return error; + address.SetPort(port); + + if (-1 == ::connect(GetNativeSocket(), &address.sockaddr(), + address.GetLength())) { + CLOSE_SOCKET(GetNativeSocket()); + continue; } - } - if (-1 == - ::connect(GetNativeSocket(), (const struct sockaddr *)&sa, sizeof(sa))) { - SetLastError(error); + SetOptionNoDelay(); + + error.Clear(); return error; } - // Keep our TCP packets coming without any delays. - SetOptionNoDelay(); - error.Clear(); + error.SetErrorString("Failed to connect port"); return error; } Error TCPSocket::Listen(llvm::StringRef name, int backlog) { - Error error; - - // enable local address reuse - SetOptionReuseAddress(); - Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_CONNECTION)); if (log) log->Printf("TCPSocket::%s (%s)", __FUNCTION__, name.data()); + Error error; std::string host_str; std::string port_str; int32_t port = INT32_MIN; if (!DecodeHostAndPort(name, host_str, port_str, port, &error)) return error; - SocketAddress bind_addr; + if (host_str == "*") + host_str = "0.0.0.0"; + auto addresses = lldb_private::SocketAddress::GetAddressInfo( + host_str.c_str(), NULL, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP); + for (auto address : addresses) { + int fd = Socket::CreateSocket(address.GetFamily(), kType, IPPROTO_TCP, + m_child_processes_inherit, error); + if (error.Fail()) { + error.Clear(); + continue; + } - // Only bind to the loopback address if we are expecting a connection from - // localhost to avoid any firewall issues. - const bool bind_addr_success = (host_str == "127.0.0.1") - ? bind_addr.SetToLocalhost(kDomain, port) - : bind_addr.SetToAnyAddress(kDomain, port); + // enable local address reuse + int option_value = 1; + set_socket_option_arg_type option_value_p = + reinterpret_cast<set_socket_option_arg_type>(&option_value); + ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, option_value_p, + sizeof(option_value)); - if (!bind_addr_success) { - error.SetErrorString("Failed to bind port"); - return error; - } + address.SetPort(port); + + int err = ::bind(fd, &address.sockaddr(), address.GetLength()); + if (-1 != err) + err = ::listen(fd, backlog); - int err = ::bind(GetNativeSocket(), bind_addr, bind_addr.GetLength()); - if (err != -1) - err = ::listen(GetNativeSocket(), backlog); + if (-1 == err) { + CLOSE_SOCKET(fd); + continue; + } - if (err == -1) - SetLastError(error); + if (port == 0) { + socklen_t sa_len = address.GetLength(); + if (getsockname(fd, &address.sockaddr(), &sa_len) == 0) + port = address.GetPort(); + } + m_listen_sockets[fd] = address; + } + if (m_listen_sockets.size() == 0) + error.SetErrorString("Failed to connect port"); return error; } -Error TCPSocket::Accept(llvm::StringRef name, bool child_processes_inherit, - Socket *&conn_socket) { +void TCPSocket::CloseListenSockets() { + for (auto socket : m_listen_sockets) + CLOSE_SOCKET(socket.first); + m_listen_sockets.clear(); +} + +Error TCPSocket::Accept(Socket *&conn_socket) { Error error; - std::string host_str; - std::string port_str; - int32_t port; - if (!DecodeHostAndPort(name, host_str, port_str, port, &error)) + if (m_listen_sockets.size() == 0) { + error.SetErrorString("No open listening sockets!"); return error; + } - const sa_family_t family = kDomain; - const int socktype = kType; - const int protocol = IPPROTO_TCP; - SocketAddress listen_addr; - if (host_str.empty()) - listen_addr.SetToLocalhost(family, port); - else if (host_str.compare("*") == 0) - listen_addr.SetToAnyAddress(family, port); - else { - if (!listen_addr.getaddrinfo(host_str.c_str(), port_str.c_str(), family, - socktype, protocol)) { - error.SetErrorStringWithFormat("unable to resolve hostname '%s'", - host_str.c_str()); + int sock = -1; + int listen_sock = -1; + lldb_private::SocketAddress AcceptAddr; + MainLoop accept_loop; + std::vector<MainLoopBase::ReadHandleUP> handles; + for (auto socket : m_listen_sockets) { + auto fd = socket.first; + auto inherit = this->m_child_processes_inherit; + auto io_sp = IOObjectSP(new TCPSocket(socket.first, false, inherit)); + handles.emplace_back(accept_loop.RegisterReadObject( + io_sp, [fd, inherit, &sock, &AcceptAddr, &error, + &listen_sock](MainLoopBase &loop) { + socklen_t sa_len = AcceptAddr.GetMaxLength(); + sock = AcceptSocket(fd, &AcceptAddr.sockaddr(), &sa_len, inherit, + error); + listen_sock = fd; + loop.RequestTermination(); + }, error)); + if (error.Fail()) return error; - } } bool accept_connection = false; std::unique_ptr<TCPSocket> accepted_socket; - // Loop until we are happy with our connection while (!accept_connection) { - struct sockaddr_in accept_addr; - ::memset(&accept_addr, 0, sizeof accept_addr); -#if !(defined(__linux__) || defined(_WIN32)) - accept_addr.sin_len = sizeof accept_addr; -#endif - socklen_t accept_addr_len = sizeof accept_addr; - - int sock = AcceptSocket(GetNativeSocket(), (struct sockaddr *)&accept_addr, - &accept_addr_len, child_processes_inherit, error); - + accept_loop.Run(); + if (error.Fail()) - break; - - bool is_same_addr = true; -#if !(defined(__linux__) || (defined(_WIN32))) - is_same_addr = (accept_addr_len == listen_addr.sockaddr_in().sin_len); -#endif - if (is_same_addr) - is_same_addr = (accept_addr.sin_addr.s_addr == - listen_addr.sockaddr_in().sin_addr.s_addr); - - if (is_same_addr || - (listen_addr.sockaddr_in().sin_addr.s_addr == INADDR_ANY)) { - accept_connection = true; - accepted_socket.reset(new TCPSocket(sock, true)); - } else { - const uint8_t *accept_ip = (const uint8_t *)&accept_addr.sin_addr.s_addr; - const uint8_t *listen_ip = - (const uint8_t *)&listen_addr.sockaddr_in().sin_addr.s_addr; - ::fprintf(stderr, "error: rejecting incoming connection from %u.%u.%u.%u " - "(expecting %u.%u.%u.%u)\n", - accept_ip[0], accept_ip[1], accept_ip[2], accept_ip[3], - listen_ip[0], listen_ip[1], listen_ip[2], listen_ip[3]); - accepted_socket.reset(); + return error; + + lldb_private::SocketAddress &AddrIn = m_listen_sockets[listen_sock]; + if (!AddrIn.IsAnyAddr() && AcceptAddr != AddrIn) { + CLOSE_SOCKET(sock); + llvm::errs() << llvm::formatv( + "error: rejecting incoming connection from {0} (expecting {1})", + AcceptAddr.GetIPAddress(), AddrIn.GetIPAddress()); + continue; } + accept_connection = true; + accepted_socket.reset(new TCPSocket(sock, *this)); } if (!accepted_socket) diff --git a/source/Host/common/UDPSocket.cpp b/source/Host/common/UDPSocket.cpp index 7ca62e7496ba0..a32657aab0a66 100644 --- a/source/Host/common/UDPSocket.cpp +++ b/source/Host/common/UDPSocket.cpp @@ -28,13 +28,16 @@ const int kDomain = AF_INET; const int kType = SOCK_DGRAM; static const char *g_not_supported_error = "Not supported"; -} +} // namespace -UDPSocket::UDPSocket(NativeSocket socket) : Socket(socket, ProtocolUdp, true) {} +UDPSocket::UDPSocket(bool should_close, bool child_processes_inherit) + : Socket(ProtocolUdp, should_close, child_processes_inherit) {} -UDPSocket::UDPSocket(bool child_processes_inherit, Error &error) - : UDPSocket( - CreateSocket(kDomain, kType, 0, child_processes_inherit, error)) {} +UDPSocket::UDPSocket(NativeSocket socket, const UDPSocket &listen_socket) + : Socket(ProtocolUdp, listen_socket.m_should_close_fd, + listen_socket.m_child_processes_inherit) { + m_socket = socket; +} size_t UDPSocket::Send(const void *buf, const size_t num_bytes) { return ::sendto(m_socket, static_cast<const char *>(buf), num_bytes, 0, @@ -42,27 +45,14 @@ size_t UDPSocket::Send(const void *buf, const size_t num_bytes) { } Error UDPSocket::Connect(llvm::StringRef name) { - return Error("%s", g_not_supported_error); -} - -Error UDPSocket::Listen(llvm::StringRef name, int backlog) { - return Error("%s", g_not_supported_error); -} - -Error UDPSocket::Accept(llvm::StringRef name, bool child_processes_inherit, - Socket *&socket) { - return Error("%s", g_not_supported_error); -} - -Error UDPSocket::Connect(llvm::StringRef name, bool child_processes_inherit, - Socket *&socket) { - std::unique_ptr<UDPSocket> final_socket; - Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_CONNECTION)); if (log) log->Printf("UDPSocket::%s (host/port = %s)", __FUNCTION__, name.data()); Error error; + if (error.Fail()) + return error; + std::string host_str; std::string port_str; int32_t port = INT32_MIN; @@ -94,12 +84,11 @@ Error UDPSocket::Connect(llvm::StringRef name, bool child_processes_inherit, for (struct addrinfo *service_info_ptr = service_info_list; service_info_ptr != nullptr; service_info_ptr = service_info_ptr->ai_next) { - auto send_fd = CreateSocket( + m_socket = Socket::CreateSocket( service_info_ptr->ai_family, service_info_ptr->ai_socktype, - service_info_ptr->ai_protocol, child_processes_inherit, error); + service_info_ptr->ai_protocol, m_child_processes_inherit, error); if (error.Success()) { - final_socket.reset(new UDPSocket(send_fd)); - final_socket->m_sockaddr = service_info_ptr; + m_sockaddr = service_info_ptr; break; } else continue; @@ -107,16 +96,17 @@ Error UDPSocket::Connect(llvm::StringRef name, bool child_processes_inherit, ::freeaddrinfo(service_info_list); - if (!final_socket) + if (IsValid()) return error; SocketAddress bind_addr; // Only bind to the loopback address if we are expecting a connection from // localhost to avoid any firewall issues. - const bool bind_addr_success = (host_str == "127.0.0.1" || host_str == "localhost") - ? bind_addr.SetToLocalhost(kDomain, port) - : bind_addr.SetToAnyAddress(kDomain, port); + const bool bind_addr_success = + (host_str == "127.0.0.1" || host_str == "localhost") + ? bind_addr.SetToLocalhost(kDomain, port) + : bind_addr.SetToAnyAddress(kDomain, port); if (!bind_addr_success) { error.SetErrorString("Failed to get hostspec to bind for"); @@ -125,13 +115,37 @@ Error UDPSocket::Connect(llvm::StringRef name, bool child_processes_inherit, bind_addr.SetPort(0); // Let the source port # be determined dynamically - err = ::bind(final_socket->GetNativeSocket(), bind_addr, bind_addr.GetLength()); - - struct sockaddr_in source_info; - socklen_t address_len = sizeof (struct sockaddr_in); - err = ::getsockname(final_socket->GetNativeSocket(), (struct sockaddr *) &source_info, &address_len); + err = ::bind(m_socket, bind_addr, bind_addr.GetLength()); - socket = final_socket.release(); error.Clear(); return error; } + +Error UDPSocket::Listen(llvm::StringRef name, int backlog) { + return Error("%s", g_not_supported_error); +} + +Error UDPSocket::Accept(Socket *&socket) { + return Error("%s", g_not_supported_error); +} + +Error UDPSocket::CreateSocket() { + Error error; + if (IsValid()) + error = Close(); + if (error.Fail()) + return error; + m_socket = + Socket::CreateSocket(kDomain, kType, 0, m_child_processes_inherit, error); + return error; +} + +Error UDPSocket::Connect(llvm::StringRef name, bool child_processes_inherit, + Socket *&socket) { + std::unique_ptr<UDPSocket> final_socket( + new UDPSocket(true, child_processes_inherit)); + Error error = final_socket->Connect(name); + if (!error.Fail()) + socket = final_socket.release(); + return error; +} |