diff options
Diffstat (limited to 'source/Host/common/TCPSocket.cpp')
| -rw-r--r-- | source/Host/common/TCPSocket.cpp | 256 | 
1 files changed, 146 insertions, 110 deletions
| diff --git a/source/Host/common/TCPSocket.cpp b/source/Host/common/TCPSocket.cpp index 9a009280a904..55db4bb0c456 100644 --- a/source/Host/common/TCPSocket.cpp +++ b/source/Host/common/TCPSocket.cpp @@ -14,30 +14,57 @@  #include "lldb/Host/common/TCPSocket.h"  #include "lldb/Host/Config.h" +#include "lldb/Host/MainLoop.h"  #include "lldb/Utility/Log.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/Support/raw_ostream.h" +  #ifndef LLDB_DISABLE_POSIX  #include <arpa/inet.h>  #include <netinet/tcp.h>  #include <sys/socket.h>  #endif +#if defined(LLVM_ON_WIN32) +#include <winsock2.h> +#endif + +#ifdef LLVM_ON_WIN32 +#define CLOSE_SOCKET closesocket +typedef const char *set_socket_option_arg_type; +#else +#define CLOSE_SOCKET ::close +typedef const void *set_socket_option_arg_type; +#endif +  using namespace lldb;  using namespace lldb_private;  namespace { - -const int kDomain = AF_INET;  const int kType = SOCK_STREAM;  } -TCPSocket::TCPSocket(NativeSocket socket, bool should_close) -    : Socket(socket, ProtocolTcp, should_close) {} +TCPSocket::TCPSocket(bool should_close, bool child_processes_inherit) +    : Socket(ProtocolTcp, should_close, child_processes_inherit) {} -TCPSocket::TCPSocket(bool child_processes_inherit, Error &error) -    : TCPSocket(CreateSocket(kDomain, kType, IPPROTO_TCP, -                             child_processes_inherit, error), -                true) {} +TCPSocket::TCPSocket(NativeSocket socket, const TCPSocket &listen_socket) +    : Socket(ProtocolTcp, listen_socket.m_should_close_fd, +             listen_socket.m_child_processes_inherit) { +  m_socket = socket; +} + +TCPSocket::TCPSocket(NativeSocket socket, bool should_close, +                     bool child_processes_inherit) +    : Socket(ProtocolTcp, should_close, child_processes_inherit) { +  m_socket = socket; +} + +TCPSocket::~TCPSocket() { CloseListenSockets(); } + +bool TCPSocket::IsValid() const { +  return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0; +}  // Return the port number that is being used by the socket.  uint16_t TCPSocket::GetLocalPortNumber() const { @@ -46,6 +73,12 @@ uint16_t TCPSocket::GetLocalPortNumber() const {      socklen_t sock_addr_len = sock_addr.GetMaxLength();      if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)        return sock_addr.GetPort(); +  } else if (!m_listen_sockets.empty()) { +    SocketAddress sock_addr; +    socklen_t sock_addr_len = sock_addr.GetMaxLength(); +    if (::getsockname(m_listen_sockets.begin()->first, sock_addr, +                      &sock_addr_len) == 0) +      return sock_addr.GetPort();    }    return 0;  } @@ -84,9 +117,18 @@ std::string TCPSocket::GetRemoteIPAddress() const {    return "";  } +Error TCPSocket::CreateSocket(int domain) { +  Error error; +  if (IsValid()) +    error = Close(); +  if (error.Fail()) +    return error; +  m_socket = Socket::CreateSocket(domain, kType, IPPROTO_TCP, +                                  m_child_processes_inherit, error); +  return error; +} +  Error TCPSocket::Connect(llvm::StringRef name) { -  if (m_socket == kInvalidSocketValue) -    return Error("Invalid socket");    Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_COMMUNICATION));    if (log) @@ -99,146 +141,140 @@ Error TCPSocket::Connect(llvm::StringRef name) {    if (!DecodeHostAndPort(name, host_str, port_str, port, &error))      return error; -  struct sockaddr_in sa; -  ::memset(&sa, 0, sizeof(sa)); -  sa.sin_family = kDomain; -  sa.sin_port = htons(port); - -  int inet_pton_result = ::inet_pton(kDomain, host_str.c_str(), &sa.sin_addr); - -  if (inet_pton_result <= 0) { -    struct hostent *host_entry = gethostbyname(host_str.c_str()); -    if (host_entry) -      host_str = ::inet_ntoa(*(struct in_addr *)*host_entry->h_addr_list); -    inet_pton_result = ::inet_pton(kDomain, host_str.c_str(), &sa.sin_addr); -    if (inet_pton_result <= 0) { -      if (inet_pton_result == -1) -        SetLastError(error); -      else -        error.SetErrorStringWithFormat("invalid host string: '%s'", -                                       host_str.c_str()); +  auto addresses = lldb_private::SocketAddress::GetAddressInfo( +      host_str.c_str(), NULL, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP); +  for (auto address : addresses) { +    error = CreateSocket(address.GetFamily()); +    if (error.Fail()) +      continue; -      return error; +    address.SetPort(port); + +    if (-1 == ::connect(GetNativeSocket(), &address.sockaddr(), +                        address.GetLength())) { +      CLOSE_SOCKET(GetNativeSocket()); +      continue;      } -  } -  if (-1 == -      ::connect(GetNativeSocket(), (const struct sockaddr *)&sa, sizeof(sa))) { -    SetLastError(error); +    SetOptionNoDelay(); + +    error.Clear();      return error;    } -  // Keep our TCP packets coming without any delays. -  SetOptionNoDelay(); -  error.Clear(); +  error.SetErrorString("Failed to connect port");    return error;  }  Error TCPSocket::Listen(llvm::StringRef name, int backlog) { -  Error error; - -  // enable local address reuse -  SetOptionReuseAddress(); -    Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_CONNECTION));    if (log)      log->Printf("TCPSocket::%s (%s)", __FUNCTION__, name.data()); +  Error error;    std::string host_str;    std::string port_str;    int32_t port = INT32_MIN;    if (!DecodeHostAndPort(name, host_str, port_str, port, &error))      return error; -  SocketAddress bind_addr; +  if (host_str == "*") +    host_str = "0.0.0.0"; +  auto addresses = lldb_private::SocketAddress::GetAddressInfo( +      host_str.c_str(), NULL, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP); +  for (auto address : addresses) { +    int fd = Socket::CreateSocket(address.GetFamily(), kType, IPPROTO_TCP, +                                  m_child_processes_inherit, error); +    if (error.Fail()) { +      error.Clear(); +      continue; +    } -  // Only bind to the loopback address if we are expecting a connection from -  // localhost to avoid any firewall issues. -  const bool bind_addr_success = (host_str == "127.0.0.1") -                                     ? bind_addr.SetToLocalhost(kDomain, port) -                                     : bind_addr.SetToAnyAddress(kDomain, port); +    // enable local address reuse +    int option_value = 1; +    set_socket_option_arg_type option_value_p = +        reinterpret_cast<set_socket_option_arg_type>(&option_value); +    ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, option_value_p, +                 sizeof(option_value)); -  if (!bind_addr_success) { -    error.SetErrorString("Failed to bind port"); -    return error; -  } +    address.SetPort(port); + +    int err = ::bind(fd, &address.sockaddr(), address.GetLength()); +    if (-1 != err) +      err = ::listen(fd, backlog); -  int err = ::bind(GetNativeSocket(), bind_addr, bind_addr.GetLength()); -  if (err != -1) -    err = ::listen(GetNativeSocket(), backlog); +    if (-1 == err) { +      CLOSE_SOCKET(fd); +      continue; +    } -  if (err == -1) -    SetLastError(error); +    if (port == 0) { +      socklen_t sa_len = address.GetLength(); +      if (getsockname(fd, &address.sockaddr(), &sa_len) == 0) +        port = address.GetPort(); +    } +    m_listen_sockets[fd] = address; +  } +  if (m_listen_sockets.size() == 0) +    error.SetErrorString("Failed to connect port");    return error;  } -Error TCPSocket::Accept(llvm::StringRef name, bool child_processes_inherit, -                        Socket *&conn_socket) { +void TCPSocket::CloseListenSockets() { +  for (auto socket : m_listen_sockets) +  CLOSE_SOCKET(socket.first); +  m_listen_sockets.clear(); +} + +Error TCPSocket::Accept(Socket *&conn_socket) {    Error error; -  std::string host_str; -  std::string port_str; -  int32_t port; -  if (!DecodeHostAndPort(name, host_str, port_str, port, &error)) +  if (m_listen_sockets.size() == 0) { +    error.SetErrorString("No open listening sockets!");      return error; +  } -  const sa_family_t family = kDomain; -  const int socktype = kType; -  const int protocol = IPPROTO_TCP; -  SocketAddress listen_addr; -  if (host_str.empty()) -    listen_addr.SetToLocalhost(family, port); -  else if (host_str.compare("*") == 0) -    listen_addr.SetToAnyAddress(family, port); -  else { -    if (!listen_addr.getaddrinfo(host_str.c_str(), port_str.c_str(), family, -                                 socktype, protocol)) { -      error.SetErrorStringWithFormat("unable to resolve hostname '%s'", -                                     host_str.c_str()); +  int sock = -1; +  int listen_sock = -1; +  lldb_private::SocketAddress AcceptAddr; +  MainLoop accept_loop; +  std::vector<MainLoopBase::ReadHandleUP> handles; +  for (auto socket : m_listen_sockets) { +    auto fd = socket.first; +    auto inherit = this->m_child_processes_inherit; +    auto io_sp = IOObjectSP(new TCPSocket(socket.first, false, inherit)); +    handles.emplace_back(accept_loop.RegisterReadObject( +        io_sp, [fd, inherit, &sock, &AcceptAddr, &error, +                        &listen_sock](MainLoopBase &loop) { +          socklen_t sa_len = AcceptAddr.GetMaxLength(); +          sock = AcceptSocket(fd, &AcceptAddr.sockaddr(), &sa_len, inherit, +                              error); +          listen_sock = fd; +          loop.RequestTermination(); +        }, error)); +    if (error.Fail())        return error; -    }    }    bool accept_connection = false;    std::unique_ptr<TCPSocket> accepted_socket; -    // Loop until we are happy with our connection    while (!accept_connection) { -    struct sockaddr_in accept_addr; -    ::memset(&accept_addr, 0, sizeof accept_addr); -#if !(defined(__linux__) || defined(_WIN32)) -    accept_addr.sin_len = sizeof accept_addr; -#endif -    socklen_t accept_addr_len = sizeof accept_addr; - -    int sock = AcceptSocket(GetNativeSocket(), (struct sockaddr *)&accept_addr, -                            &accept_addr_len, child_processes_inherit, error); - +    accept_loop.Run(); +          if (error.Fail()) -      break; - -    bool is_same_addr = true; -#if !(defined(__linux__) || (defined(_WIN32))) -    is_same_addr = (accept_addr_len == listen_addr.sockaddr_in().sin_len); -#endif -    if (is_same_addr) -      is_same_addr = (accept_addr.sin_addr.s_addr == -                      listen_addr.sockaddr_in().sin_addr.s_addr); - -    if (is_same_addr || -        (listen_addr.sockaddr_in().sin_addr.s_addr == INADDR_ANY)) { -      accept_connection = true; -      accepted_socket.reset(new TCPSocket(sock, true)); -    } else { -      const uint8_t *accept_ip = (const uint8_t *)&accept_addr.sin_addr.s_addr; -      const uint8_t *listen_ip = -          (const uint8_t *)&listen_addr.sockaddr_in().sin_addr.s_addr; -      ::fprintf(stderr, "error: rejecting incoming connection from %u.%u.%u.%u " -                        "(expecting %u.%u.%u.%u)\n", -                accept_ip[0], accept_ip[1], accept_ip[2], accept_ip[3], -                listen_ip[0], listen_ip[1], listen_ip[2], listen_ip[3]); -      accepted_socket.reset(); +        return error; + +    lldb_private::SocketAddress &AddrIn = m_listen_sockets[listen_sock]; +    if (!AddrIn.IsAnyAddr() && AcceptAddr != AddrIn) { +      CLOSE_SOCKET(sock); +      llvm::errs() << llvm::formatv( +          "error: rejecting incoming connection from {0} (expecting {1})", +          AcceptAddr.GetIPAddress(), AddrIn.GetIPAddress()); +      continue;      } +    accept_connection = true; +    accepted_socket.reset(new TCPSocket(sock, *this));    }    if (!accepted_socket) | 
