diff options
Diffstat (limited to 'tests/atf_python/sys/net')
-rw-r--r-- | tests/atf_python/sys/net/Makefile | 11 | ||||
-rw-r--r-- | tests/atf_python/sys/net/__init__.py | 0 | ||||
-rwxr-xr-x | tests/atf_python/sys/net/rtsock.py | 604 | ||||
-rw-r--r-- | tests/atf_python/sys/net/tools.py | 100 | ||||
-rw-r--r-- | tests/atf_python/sys/net/vnet.py | 559 |
5 files changed, 1274 insertions, 0 deletions
diff --git a/tests/atf_python/sys/net/Makefile b/tests/atf_python/sys/net/Makefile new file mode 100644 index 000000000000..70d5b1a3284b --- /dev/null +++ b/tests/atf_python/sys/net/Makefile @@ -0,0 +1,11 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +PACKAGE=tests +FILES= __init__.py rtsock.py tools.py vnet.py + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python/sys/net + +.include <bsd.prog.mk> diff --git a/tests/atf_python/sys/net/__init__.py b/tests/atf_python/sys/net/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 --- /dev/null +++ b/tests/atf_python/sys/net/__init__.py diff --git a/tests/atf_python/sys/net/rtsock.py b/tests/atf_python/sys/net/rtsock.py new file mode 100755 index 000000000000..788e863f8b28 --- /dev/null +++ b/tests/atf_python/sys/net/rtsock.py @@ -0,0 +1,604 @@ +#!/usr/local/bin/python3 +import os +import socket +import struct +import sys +from ctypes import c_byte +from ctypes import c_char +from ctypes import c_int +from ctypes import c_long +from ctypes import c_uint32 +from ctypes import c_ulong +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + + +def roundup2(val: int, num: int) -> int: + if val % num: + return (val | (num - 1)) + 1 + else: + return val + + +class RtSockException(OSError): + pass + + +class RtConst: + RTM_VERSION = 5 + ALIGN = sizeof(c_long) + + AF_INET = socket.AF_INET + AF_INET6 = socket.AF_INET6 + AF_LINK = socket.AF_LINK + + RTA_DST = 0x1 + RTA_GATEWAY = 0x2 + RTA_NETMASK = 0x4 + RTA_GENMASK = 0x8 + RTA_IFP = 0x10 + RTA_IFA = 0x20 + RTA_AUTHOR = 0x40 + RTA_BRD = 0x80 + + RTM_ADD = 1 + RTM_DELETE = 2 + RTM_CHANGE = 3 + RTM_GET = 4 + + RTF_UP = 0x1 + RTF_GATEWAY = 0x2 + RTF_HOST = 0x4 + RTF_REJECT = 0x8 + RTF_DYNAMIC = 0x10 + RTF_MODIFIED = 0x20 + RTF_DONE = 0x40 + RTF_XRESOLVE = 0x200 + RTF_LLINFO = 0x400 + RTF_LLDATA = 0x400 + RTF_STATIC = 0x800 + RTF_BLACKHOLE = 0x1000 + RTF_PROTO2 = 0x4000 + RTF_PROTO1 = 0x8000 + RTF_PROTO3 = 0x40000 + RTF_FIXEDMTU = 0x80000 + RTF_PINNED = 0x100000 + RTF_LOCAL = 0x200000 + RTF_BROADCAST = 0x400000 + RTF_MULTICAST = 0x800000 + RTF_STICKY = 0x10000000 + RTF_RNH_LOCKED = 0x40000000 + RTF_GWFLAG_COMPAT = 0x80000000 + + RTV_MTU = 0x1 + RTV_HOPCOUNT = 0x2 + RTV_EXPIRE = 0x4 + RTV_RPIPE = 0x8 + RTV_SPIPE = 0x10 + RTV_SSTHRESH = 0x20 + RTV_RTT = 0x40 + RTV_RTTVAR = 0x80 + RTV_WEIGHT = 0x100 + + @staticmethod + def get_props(prefix: str) -> List[str]: + return [n for n in dir(RtConst) if n.startswith(prefix)] + + @staticmethod + def get_name(prefix: str, value: int) -> str: + props = RtConst.get_props(prefix) + for prop in props: + if getattr(RtConst, prop) == value: + return prop + return "U:{}:{}".format(prefix, value) + + @staticmethod + def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]: + props = RtConst.get_props(prefix) + propmap = {getattr(RtConst, prop): prop for prop in props} + v = 1 + ret = {} + while value: + if v & value: + if v in propmap: + ret[v] = propmap[v] + else: + ret[v] = hex(v) + value -= v + v *= 2 + return ret + + @staticmethod + def get_bitmask_str(prefix: str, value: int) -> str: + bmap = RtConst.get_bitmask_map(prefix, value) + return ",".join([v for k, v in bmap.items()]) + + +class RtMetrics(Structure): + _fields_ = [ + ("rmx_locks", c_ulong), + ("rmx_mtu", c_ulong), + ("rmx_hopcount", c_ulong), + ("rmx_expire", c_ulong), + ("rmx_recvpipe", c_ulong), + ("rmx_sendpipe", c_ulong), + ("rmx_ssthresh", c_ulong), + ("rmx_rtt", c_ulong), + ("rmx_rttvar", c_ulong), + ("rmx_pksent", c_ulong), + ("rmx_weight", c_ulong), + ("rmx_nhidx", c_ulong), + ("rmx_filler", c_ulong * 2), + ] + + +class RtMsgHdr(Structure): + _fields_ = [ + ("rtm_msglen", c_ushort), + ("rtm_version", c_byte), + ("rtm_type", c_byte), + ("rtm_index", c_ushort), + ("_rtm_spare1", c_ushort), + ("rtm_flags", c_int), + ("rtm_addrs", c_int), + ("rtm_pid", c_int), + ("rtm_seq", c_int), + ("rtm_errno", c_int), + ("rtm_fmask", c_int), + ("rtm_inits", c_ulong), + ("rtm_rmx", RtMetrics), + ] + + +class SockaddrIn(Structure): + _fields_ = [ + ("sin_len", c_byte), + ("sin_family", c_byte), + ("sin_port", c_ushort), + ("sin_addr", c_uint32), + ("sin_zero", c_char * 8), + ] + + +class SockaddrIn6(Structure): + _fields_ = [ + ("sin6_len", c_byte), + ("sin6_family", c_byte), + ("sin6_port", c_ushort), + ("sin6_flowinfo", c_uint32), + ("sin6_addr", c_byte * 16), + ("sin6_scope_id", c_uint32), + ] + + +class SockaddrDl(Structure): + _fields_ = [ + ("sdl_len", c_byte), + ("sdl_family", c_byte), + ("sdl_index", c_ushort), + ("sdl_type", c_byte), + ("sdl_nlen", c_byte), + ("sdl_alen", c_byte), + ("sdl_slen", c_byte), + ("sdl_data", c_byte * 8), + ] + + +class SaHelper(object): + @staticmethod + def is_ipv6(ip: str) -> bool: + return ":" in ip + + @staticmethod + def ip_sa(ip: str, scopeid: int = 0) -> bytes: + if SaHelper.is_ipv6(ip): + return SaHelper.ip6_sa(ip, scopeid) + else: + return SaHelper.ip4_sa(ip) + + @staticmethod + def ip4_sa(ip: str) -> bytes: + addr_int = int.from_bytes(socket.inet_pton(2, ip), sys.byteorder) + sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int) + return bytes(sin) + + @staticmethod + def ip6_sa(ip6: str, scopeid: int) -> bytes: + addr_bytes = (c_byte * 16)() + for i, b in enumerate(socket.inet_pton(socket.AF_INET6, ip6)): + addr_bytes[i] = b + sin6 = SockaddrIn6( + sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid + ) + return bytes(sin6) + + @staticmethod + def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes: + sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype) + return bytes(sa) + + @staticmethod + def pxlen4_sa(pxlen: int) -> bytes: + return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen)) + + @staticmethod + def pxlen_to_ip4(pxlen: int) -> str: + if pxlen == 32: + return "255.255.255.255" + else: + addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1) + addr_bytes = struct.pack("!I", addr) + return socket.inet_ntop(socket.AF_INET, addr_bytes) + + @staticmethod + def pxlen6_sa(pxlen: int) -> bytes: + return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen)) + + @staticmethod + def pxlen_to_ip6(pxlen: int) -> str: + ip6_b = [0] * 16 + start = 0 + while pxlen > 8: + ip6_b[start] = 0xFF + pxlen -= 8 + start += 1 + ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1) + return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b)) + + @staticmethod + def print_sa_inet(sa: bytes): + if len(sa) < 8: + raise RtSockException("IPv4 sa size too small: {}".format(len(sa))) + addr = socket.inet_ntop(socket.AF_INET, sa[4:8]) + return "{}".format(addr) + + @staticmethod + def print_sa_inet6(sa: bytes): + if len(sa) < sizeof(SockaddrIn6): + raise RtSockException("IPv6 sa size too small: {}".format(len(sa))) + addr = socket.inet_ntop(socket.AF_INET6, sa[8:24]) + scopeid = struct.unpack(">I", sa[24:28])[0] + return "{} scopeid {}".format(addr, scopeid) + + @staticmethod + def print_sa_link(sa: bytes, hd: Optional[bool] = True): + if len(sa) < sizeof(SockaddrDl): + raise RtSockException("LINK sa size too small: {}".format(len(sa))) + sdl = SockaddrDl.from_buffer_copy(sa) + if sdl.sdl_index: + ifindex = "link#{} ".format(sdl.sdl_index) + else: + ifindex = "" + if sdl.sdl_nlen: + iface_offset = 8 + if sdl.sdl_nlen + iface_offset > len(sa): + raise RtSockException( + "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa)) + ) + ifname = "ifname:{} ".format( + bytes.decode(sa[iface_offset : iface_offset + sdl.sdl_nlen]) + ) + else: + ifname = "" + return "{}{}".format(ifindex, ifname) + + @staticmethod + def print_sa_unknown(sa: bytes): + return "unknown_type:{}".format(sa[1]) + + @classmethod + def print_sa(cls, sa: bytes, hd: Optional[bool] = False): + if sa[0] != len(sa): + raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa))) + + if len(sa) < 2: + raise Exception( + "sa type {} too short: {}".format( + RtConst.get_name("AF_", sa[1]), len(sa) + ) + ) + + if sa[1] == socket.AF_INET: + text = cls.print_sa_inet(sa) + elif sa[1] == socket.AF_INET6: + text = cls.print_sa_inet6(sa) + elif sa[1] == socket.AF_LINK: + text = cls.print_sa_link(sa) + else: + text = cls.print_sa_unknown(sa) + if hd: + dump = " [{!r}]".format(sa) + else: + dump = "" + return "{}{}".format(text, dump) + + +class BaseRtsockMessage(object): + def __init__(self, rtm_type): + self.rtm_type = rtm_type + self.sa = SaHelper() + + @staticmethod + def print_rtm_type(rtm_type): + return RtConst.get_name("RTM_", rtm_type) + + @property + def rtm_type_str(self): + return self.print_rtm_type(self.rtm_type) + + +class RtsockRtMessage(BaseRtsockMessage): + messages = [ + RtConst.RTM_ADD, + RtConst.RTM_DELETE, + RtConst.RTM_CHANGE, + RtConst.RTM_GET, + ] + + def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None): + super().__init__(rtm_type) + self.rtm_flags = 0 + self.rtm_seq = rtm_seq + self._attrs = {} + self.rtm_errno = 0 + self.rtm_pid = 0 + self.rtm_inits = 0 + self.rtm_rmx = RtMetrics() + self._orig_data = None + if dst_sa: + self.add_sa_attr(RtConst.RTA_DST, dst_sa) + if mask_sa: + self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa) + + def add_sa_attr(self, attr_type, attr_bytes: bytes): + self._attrs[attr_type] = attr_bytes + + def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0): + if ":" in ip_addr: + self.add_ip6_attr(attr_type, ip_addr, scopeid) + else: + self.add_ip4_attr(attr_type, ip_addr) + + def add_ip4_attr(self, attr_type, ip: str): + self.add_sa_attr(attr_type, self.sa.ip_sa(ip)) + + def add_ip6_attr(self, attr_type, ip6: str, scopeid: int): + self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid)) + + def add_link_attr(self, attr_type, ifindex: Optional[int] = 0): + self.add_sa_attr(attr_type, self.sa.link_sa(ifindex)) + + def get_sa(self, attr_type) -> bytes: + return self._attrs.get(attr_type) + + def print_message(self): + # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC> + if self._orig_data: + rtm_len = len(self._orig_data) + else: + rtm_len = len(bytes(self)) + print( + "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format( + self.rtm_type_str, + rtm_len, + self.rtm_pid, + self.rtm_seq, + self.rtm_errno, + RtConst.get_bitmask_str("RTF_", self.rtm_flags), + ) + ) + rtm_addrs = sum(list(self._attrs.keys())) + print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs))) + for attr in sorted(self._attrs.keys()): + sa_data = SaHelper.print_sa(self._attrs[attr]) + print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data)) + + def print_in_message(self): + print("vvvvvvvv IN vvvvvvvv") + self.print_message() + print() + + def verify_sa_inet(self, sa_data): + if len(sa_data) < 8: + raise Exception("IPv4 sa size too small: {}".format(sa_data)) + if sa_data[0] > len(sa_data): + raise Exception( + "IPv4 sin_len too big: {} vs sa size {}: {}".format( + sa_data[0], len(sa_data), sa_data + ) + ) + sin = SockaddrIn.from_buffer_copy(sa_data) + assert sin.sin_port == 0 + assert sin.sin_zero == [0] * 8 + + def compare_sa(self, sa_type, sa_data): + if len(sa_data) < 4: + sa_type_name = RtConst.get_name("RTA_", sa_type) + raise Exception( + "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data)) + ) + our_sa = self._attrs[sa_type] + assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa) + assert len(sa_data) == len(our_sa) + assert sa_data == our_sa + + def verify(self, rtm_type: int, rtm_sa): + assert self.rtm_type_str == self.print_rtm_type(rtm_type) + assert self.rtm_errno == 0 + hdr = RtMsgHdr.from_buffer_copy(self._orig_data) + assert hdr._rtm_spare1 == 0 + for sa_type, sa_data in rtm_sa.items(): + if sa_type not in self._attrs: + sa_type_name = RtConst.get_name("RTA_", sa_type) + raise Exception("SA type {} not present".format(sa_type_name)) + self.compare_sa(sa_type, sa_data) + + @classmethod + def from_bytes(cls, data: bytes): + if len(data) < sizeof(RtMsgHdr): + raise Exception( + "messages size {} is less than expected {}".format( + len(data), sizeof(RtMsgHdr) + ) + ) + hdr = RtMsgHdr.from_buffer_copy(data) + + self = cls(hdr.rtm_type) + self.rtm_flags = hdr.rtm_flags + self.rtm_seq = hdr.rtm_seq + self.rtm_errno = hdr.rtm_errno + self.rtm_pid = hdr.rtm_pid + self.rtm_inits = hdr.rtm_inits + self.rtm_rmx = hdr.rtm_rmx + self._orig_data = data + + off = sizeof(RtMsgHdr) + v = 1 + addrs_mask = hdr.rtm_addrs + while addrs_mask: + if addrs_mask & v: + addrs_mask -= v + + if off + data[off] > len(data): + raise Exception( + "SA sizeof for {} > total message length: {}+{} > {}".format( + RtConst.get_name("RTA_", v), off, data[off], len(data) + ) + ) + self._attrs[v] = data[off : off + data[off]] + off += roundup2(data[off], RtConst.ALIGN) + v *= 2 + return self + + def __bytes__(self): + sz = sizeof(RtMsgHdr) + addrs_mask = 0 + for k, v in self._attrs.items(): + sz += roundup2(len(v), RtConst.ALIGN) + addrs_mask += k + hdr = RtMsgHdr( + rtm_msglen=sz, + rtm_version=RtConst.RTM_VERSION, + rtm_type=self.rtm_type, + rtm_flags=self.rtm_flags, + rtm_seq=self.rtm_seq, + rtm_addrs=addrs_mask, + rtm_inits=self.rtm_inits, + rtm_rmx=self.rtm_rmx, + ) + buf = bytearray(sz) + buf[0 : sizeof(RtMsgHdr)] = hdr + off = sizeof(RtMsgHdr) + for attr in sorted(self._attrs.keys()): + v = self._attrs[attr] + sa_len = len(v) + buf[off : off + sa_len] = v + off += roundup2(len(v), RtConst.ALIGN) + return bytes(buf) + + +class Rtsock: + def __init__(self): + self.socket = self._setup_rtsock() + self.rtm_seq = 1 + self.msgmap = self.build_msgmap() + + def build_msgmap(self): + classes = [RtsockRtMessage] + xmap = {} + for cls in classes: + for message in cls.messages: + xmap[message] = cls + return xmap + + def get_seq(self): + ret = self.rtm_seq + self.rtm_seq += 1 + return ret + + def get_weight(self, weight) -> int: + if weight: + return weight + else: + return 1 # RT_DEFAULT_WEIGHT + + def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]): + px = prefix.split("/") + addr_sa = SaHelper.ip_sa(px[0]) + if len(px) > 1: + pxlen = int(px[1]) + if SaHelper.is_ipv6(px[0]): + mask_sa = SaHelper.pxlen6_sa(pxlen) + else: + mask_sa = SaHelper.pxlen4_sa(pxlen) + else: + mask_sa = None + msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa) + if isinstance(gw, bytes): + msg.add_sa_attr(RtConst.RTA_GATEWAY, gw) + else: + # String + msg.add_ip_attr(RtConst.RTA_GATEWAY, gw) + return msg + + def new_rtm_add(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw) + + def new_rtm_del(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw) + + def new_rtm_change(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw) + + def _setup_rtsock(self) -> socket.socket: + s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC) + s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1) + return s + + def print_hd(self, data: bytes): + width = 16 + print("==========================================") + for chunk in [data[i : i + width] for i in range(0, len(data), width)]: + for b in chunk: + print("0x{:02X} ".format(b), end="") + print() + print() + + def write_message(self, msg): + print("vvvvvvvv OUT vvvvvvvv") + msg.print_message() + print() + msg_bytes = bytes(msg) + ret = os.write(self.socket.fileno(), msg_bytes) + if ret != -1: + assert ret == len(msg_bytes) + + def parse_message(self, data: bytes): + if len(data) < 4: + raise OSError("Short read from rtsock: {} bytes".format(len(data))) + rtm_type = data[4] + if rtm_type not in self.msgmap: + return None + + def write_data(self, data: bytes): + self.socket.send(data) + + def read_data(self, seq: Optional[int] = None) -> bytes: + while True: + data = self.socket.recv(4096) + if seq is None: + break + if len(data) > sizeof(RtMsgHdr): + hdr = RtMsgHdr.from_buffer_copy(data) + if hdr.rtm_seq == seq: + break + return data + + def read_message(self) -> bytes: + data = self.read_data() + return self.parse_message(data) diff --git a/tests/atf_python/sys/net/tools.py b/tests/atf_python/sys/net/tools.py new file mode 100644 index 000000000000..44bd74d8578f --- /dev/null +++ b/tests/atf_python/sys/net/tools.py @@ -0,0 +1,100 @@ +#!/usr/local/bin/python3 +import json +import os +import subprocess + + +class ToolsHelper(object): + NETSTAT_PATH = "/usr/bin/netstat" + IFCONFIG_PATH = "/sbin/ifconfig" + + @classmethod + def get_output(cls, cmd: str, verbose=False) -> str: + if verbose: + print("run: '{}'".format(cmd)) + return os.popen(cmd).read() + + @classmethod + def pf_rules(cls, rules, verbose=True): + pf_conf = "" + for r in rules: + pf_conf = pf_conf + r + "\n" + + if verbose: + print("Set rules:") + print(pf_conf) + + ps = subprocess.Popen("/sbin/pfctl -g -f -", shell=True, + stdin=subprocess.PIPE) + ps.communicate(bytes(pf_conf, 'utf-8')) + ret = ps.wait() + if ret != 0: + raise Exception("Failed to set pf rules %d" % ret) + + if verbose: + cls.print_output("/sbin/pfctl -sr") + + @classmethod + def print_output(cls, cmd: str, verbose=True): + if verbose: + print("======= {} =====".format(cmd)) + print(cls.get_output(cmd)) + if verbose: + print() + + @classmethod + def print_net_debug(cls): + cls.print_output("ifconfig") + cls.print_output("netstat -rnW") + + @classmethod + def set_sysctl(cls, oid, val): + cls.get_output("sysctl {}={}".format(oid, val)) + + @classmethod + def get_routes(cls, family: str, fibnum: int = 0): + family_key = {"inet": "-4", "inet6": "-6"}.get(family) + out = cls.get_output( + "{} {} -rnW -F {} --libxo json".format(cls.NETSTAT_PATH, family_key, fibnum) + ) + js = json.loads(out) + js = js["statistics"]["route-information"]["route-table"]["rt-family"] + if js: + return js[0]["rt-entry"] + else: + return [] + + @classmethod + def get_nhops(cls, family: str, fibnum: int = 0): + family_key = {"inet": "-4", "inet6": "-6"}.get(family) + out = cls.get_output( + "{} {} -onW -F {} --libxo json".format(cls.NETSTAT_PATH, family_key, fibnum) + ) + js = json.loads(out) + js = js["statistics"]["route-nhop-information"]["nhop-table"]["rt-family"] + if js: + return js[0]["nh-entry"] + else: + return [] + + @classmethod + def get_linklocals(cls): + ret = {} + ifname = None + ips = [] + for line in cls.get_output(cls.IFCONFIG_PATH).splitlines(): + if line[0].isalnum(): + if ifname: + ret[ifname] = ips + ips = [] + ifname = line.split(":")[0] + else: + words = line.split() + if words[0] == "inet6" and words[1].startswith("fe80"): + # inet6 fe80::1%lo0 prefixlen 64 scopeid 0x2 + ip = words[1].split("%")[0] + scopeid = int(words[words.index("scopeid") + 1], 16) + ips.append((ip, scopeid)) + if ifname: + ret[ifname] = ips + return ret diff --git a/tests/atf_python/sys/net/vnet.py b/tests/atf_python/sys/net/vnet.py new file mode 100644 index 000000000000..f75a3eaa693e --- /dev/null +++ b/tests/atf_python/sys/net/vnet.py @@ -0,0 +1,559 @@ +#!/usr/local/bin/python3 +import copy +import ipaddress +import os +import re +import socket +import sys +import time +from multiprocessing import connection +from multiprocessing import Pipe +from multiprocessing import Process +from typing import Dict +from typing import List +from typing import NamedTuple + +from atf_python.sys.net.tools import ToolsHelper +from atf_python.utils import BaseTest +from atf_python.utils import libc + + +def run_cmd(cmd: str, verbose=True) -> str: + if verbose: + print("run: '{}'".format(cmd)) + return os.popen(cmd).read() + + +def get_topology_id(test_id: str) -> str: + """ + Gets a unique topology id based on the pytest test_id. + "test_ip6_output.py::TestIP6Output::test_output6_pktinfo[ipandif]" -> + "TestIP6Output:test_output6_pktinfo[ipandif]" + """ + return ":".join(test_id.split("::")[-2:]) + + +def convert_test_name(test_name: str) -> str: + """Convert test name to a string that can be used in the file/jail names""" + ret = "" + for char in test_name: + if char.isalnum() or char in ("_", "-", ":"): + ret += char + elif char in ("["): + ret += "_" + return ret + + +class VnetInterface(object): + # defines from net/if_types.h + IFT_LOOP = 0x18 + IFT_ETHER = 0x06 + + def __init__(self, iface_alias: str, iface_name: str): + self.name = iface_name + self.alias = iface_alias + self.vnet_name = "" + self.jailed = False + self.addr_map: Dict[str, Dict] = {"inet6": {}, "inet": {}} + self.prefixes4: List[List[str]] = [] + self.prefixes6: List[List[str]] = [] + if iface_name.startswith("lo"): + self.iftype = self.IFT_LOOP + else: + self.iftype = self.IFT_ETHER + self.ether = ToolsHelper.get_output("/sbin/ifconfig %s ether | awk '/ether/ { print $2; }'" % iface_name).rstrip() + + @property + def ifindex(self): + return socket.if_nametoindex(self.name) + + @property + def first_ipv6(self): + d = self.addr_map["inet6"] + return d[next(iter(d))] + + @property + def first_ipv4(self): + d = self.addr_map["inet"] + return d[next(iter(d))] + + def set_vnet(self, vnet_name: str): + self.vnet_name = vnet_name + + def set_jailed(self, jailed: bool): + self.jailed = jailed + + def run_cmd(self, cmd, verbose=False): + if self.vnet_name and not self.jailed: + cmd = "/usr/sbin/jexec {} {}".format(self.vnet_name, cmd) + return run_cmd(cmd, verbose) + + @classmethod + def setup_loopback(cls, vnet_name: str): + lo = VnetInterface("", "lo0") + lo.set_vnet(vnet_name) + lo.setup_addr("127.0.0.1/8") + lo.turn_up() + + @classmethod + def create_iface(cls, alias_name: str, iface_name: str) -> List["VnetInterface"]: + name = run_cmd("/sbin/ifconfig {} create".format(iface_name)).rstrip() + if not name: + raise Exception("Unable to create iface {}".format(iface_name)) + if1 = cls(alias_name, name) + ret = [if1] + if name.startswith("epair"): + run_cmd("/sbin/ifconfig {} -txcsum -txcsum6".format(name)) + if2 = cls(alias_name, name[:-1] + "b") + if1.epairb = if2 + ret.append(if2); + return ret + + def setup_addr(self, _addr: str): + addr = ipaddress.ip_interface(_addr) + if addr.version == 6: + family = "inet6" + cmd = "/sbin/ifconfig {} {} {}".format(self.name, family, addr) + else: + family = "inet" + if self.addr_map[family]: + cmd = "/sbin/ifconfig {} alias {}".format(self.name, addr) + else: + cmd = "/sbin/ifconfig {} {} {}".format(self.name, family, addr) + self.run_cmd(cmd) + self.addr_map[family][str(addr.ip)] = addr + + def delete_addr(self, _addr: str): + addr = ipaddress.ip_address(_addr) + if addr.version == 6: + family = "inet6" + cmd = "/sbin/ifconfig {} inet6 {} delete".format(self.name, addr) + else: + family = "inet" + cmd = "/sbin/ifconfig {} -alias {}".format(self.name, addr) + self.run_cmd(cmd) + del self.addr_map[family][str(addr)] + + def turn_up(self): + cmd = "/sbin/ifconfig {} up".format(self.name) + self.run_cmd(cmd) + + def enable_ipv6(self): + cmd = "/usr/sbin/ndp -i {} -- -disabled".format(self.name) + self.run_cmd(cmd) + + def has_tentative(self) -> bool: + """True if an interface has some addresses in tenative state""" + cmd = "/sbin/ifconfig {} inet6".format(self.name) + out = self.run_cmd(cmd, verbose=False) + for line in out.splitlines(): + if "tentative" in line: + return True + return False + + +class IfaceFactory(object): + INTERFACES_FNAME = "created_ifaces.lst" + AUTODELETE_TYPES = ("epair", "gif", "gre", "lo", "tap", "tun") + + def __init__(self): + self.file_name = self.INTERFACES_FNAME + + def _register_iface(self, iface_name: str): + with open(self.file_name, "a") as f: + f.write(iface_name + "\n") + + def _list_ifaces(self) -> List[str]: + ret: List[str] = [] + try: + with open(self.file_name, "r") as f: + for line in f: + ret.append(line.strip()) + except OSError: + pass + return ret + + def create_iface(self, alias_name: str, iface_name: str) -> List[VnetInterface]: + ifaces = VnetInterface.create_iface(alias_name, iface_name) + for iface in ifaces: + if not self.is_autodeleted(iface.name): + self._register_iface(iface.name) + return ifaces + + @staticmethod + def is_autodeleted(iface_name: str) -> bool: + if iface_name == "lo0": + return False + iface_type = re.split(r"\d+", iface_name)[0] + return iface_type in IfaceFactory.AUTODELETE_TYPES + + def cleanup_vnet_interfaces(self, vnet_name: str) -> List[str]: + """Destroys""" + ifaces_lst = ToolsHelper.get_output( + "/usr/sbin/jexec {} /sbin/ifconfig -l".format(vnet_name) + ) + for iface_name in ifaces_lst.split(): + if not self.is_autodeleted(iface_name): + if iface_name not in self._list_ifaces(): + print("Skipping interface {}:{}".format(vnet_name, iface_name)) + continue + run_cmd( + "/usr/sbin/jexec {} /sbin/ifconfig {} destroy".format(vnet_name, iface_name) + ) + + def cleanup(self): + try: + os.unlink(self.INTERFACES_FNAME) + except OSError: + pass + + +class VnetInstance(object): + def __init__( + self, vnet_alias: str, vnet_name: str, jid: int, ifaces: List[VnetInterface] + ): + self.name = vnet_name + self.alias = vnet_alias # reference in the test topology + self.jid = jid + self.ifaces = ifaces + self.iface_alias_map = {} # iface.alias: iface + self.iface_map = {} # iface.name: iface + for iface in ifaces: + iface.set_vnet(vnet_name) + iface.set_jailed(True) + self.iface_alias_map[iface.alias] = iface + self.iface_map[iface.name] = iface + # Allow reference to interfce aliases as attributes + setattr(self, iface.alias, iface) + self.need_dad = False # Disable duplicate address detection by default + self.attached = False + self.pipe = None + self.subprocess = None + + def run_vnet_cmd(self, cmd, verbose=True): + if not self.attached: + cmd = "/usr/sbin/jexec {} {}".format(self.name, cmd) + return run_cmd(cmd, verbose) + + def disable_dad(self): + self.run_vnet_cmd("/sbin/sysctl net.inet6.ip6.dad_count=0") + + def set_pipe(self, pipe): + self.pipe = pipe + + def set_subprocess(self, p): + self.subprocess = p + + @staticmethod + def attach_jid(jid: int): + error_code = libc.jail_attach(jid) + if error_code != 0: + raise Exception("jail_attach() failed: errno {}".format(error_code)) + + def attach(self): + self.attach_jid(self.jid) + self.attached = True + + +class VnetFactory(object): + JAILS_FNAME = "created_jails.lst" + + def __init__(self, topology_id: str): + self.topology_id = topology_id + self.file_name = self.JAILS_FNAME + self._vnets: List[str] = [] + + def _register_vnet(self, vnet_name: str): + self._vnets.append(vnet_name) + with open(self.file_name, "a") as f: + f.write(vnet_name + "\n") + + @staticmethod + def _wait_interfaces(vnet_name: str, ifaces: List[str]) -> List[str]: + cmd = "/usr/sbin/jexec {} /sbin/ifconfig -l".format(vnet_name) + not_matched: List[str] = [] + for i in range(50): + vnet_ifaces = run_cmd(cmd).strip().split(" ") + not_matched = [] + for iface_name in ifaces: + if iface_name not in vnet_ifaces: + not_matched.append(iface_name) + if len(not_matched) == 0: + return [] + time.sleep(0.1) + return not_matched + + def create_vnet(self, vnet_alias: str, ifaces: List[VnetInterface], opts: List[str]): + vnet_name = "pytest:{}".format(convert_test_name(self.topology_id)) + if self._vnets: + # add number to distinguish jails + vnet_name = "{}_{}".format(vnet_name, len(self._vnets) + 1) + iface_cmds = " ".join(["vnet.interface={}".format(i.name) for i in ifaces]) + opt_cmds = " ".join(["{}".format(i) for i in opts]) + cmd = "/usr/sbin/jail -i -c name={} persist vnet {} {}".format( + vnet_name, iface_cmds, opt_cmds + ) + jid = 0 + try: + jid_str = run_cmd(cmd) + jid = int(jid_str) + except ValueError: + print("Jail creation failed, output: {}".format(jid_str)) + raise + self._register_vnet(vnet_name) + + # Run expedited version of routing + VnetInterface.setup_loopback(vnet_name) + + not_found = self._wait_interfaces(vnet_name, [i.name for i in ifaces]) + if not_found: + raise Exception( + "Interfaces {} has not appeared in vnet {}".format(not_found, vnet_name) + ) + return VnetInstance(vnet_alias, vnet_name, jid, ifaces) + + def cleanup(self): + iface_factory = IfaceFactory() + try: + with open(self.file_name) as f: + for line in f: + vnet_name = line.strip() + iface_factory.cleanup_vnet_interfaces(vnet_name) + run_cmd("/usr/sbin/jail -r {}".format(vnet_name)) + os.unlink(self.JAILS_FNAME) + except OSError: + pass + + +class SingleInterfaceMap(NamedTuple): + ifaces: List[VnetInterface] + vnet_aliases: List[str] + + +class ObjectsMap(NamedTuple): + iface_map: Dict[str, SingleInterfaceMap] # keyed by ifX + vnet_map: Dict[str, VnetInstance] # keyed by vnetX + topo_map: Dict # self.TOPOLOGY + + +class VnetTestTemplate(BaseTest): + NEED_ROOT: bool = True + TOPOLOGY = {} + + def _require_default_modules(self): + libc.kldload("if_epair.ko") + self.require_module("if_epair") + + def _get_vnet_handler(self, vnet_alias: str): + handler_name = "{}_handler".format(vnet_alias) + return getattr(self, handler_name, None) + + def _setup_vnet(self, vnet: VnetInstance, obj_map: Dict, pipe): + """Base Handler to setup given VNET. + Can be run in a subprocess. If so, passes control to the special + vnetX_handler() after setting up interface addresses + """ + vnet.attach() + print("# setup_vnet({})".format(vnet.name)) + if pipe is not None: + vnet.set_pipe(pipe) + + topo = obj_map.topo_map + ipv6_ifaces = [] + # Disable DAD + if not vnet.need_dad: + vnet.disable_dad() + for iface in vnet.ifaces: + # check index of vnet within an interface + # as we have prefixes for both ends of the interface + iface_map = obj_map.iface_map[iface.alias] + idx = iface_map.vnet_aliases.index(vnet.alias) + prefixes6 = topo[iface.alias].get("prefixes6", []) + prefixes4 = topo[iface.alias].get("prefixes4", []) + if prefixes6 or prefixes4: + ipv6_ifaces.append(iface) + iface.turn_up() + if prefixes6: + iface.enable_ipv6() + for prefix in prefixes6 + prefixes4: + if prefix[idx]: + iface.setup_addr(prefix[idx]) + for iface in ipv6_ifaces: + while iface.has_tentative(): + time.sleep(0.1) + + # Run actual handler + handler = self._get_vnet_handler(vnet.alias) + if handler: + # Do unbuffered stdout for children + # so the logs are present if the child hangs + sys.stdout.reconfigure(line_buffering=True) + self.drop_privileges() + handler(vnet) + + def _get_topo_ifmap(self, topo: Dict): + iface_factory = IfaceFactory() + iface_map: Dict[str, SingleInterfaceMap] = {} + iface_aliases = set() + for obj_name, obj_data in topo.items(): + if obj_name.startswith("vnet"): + for iface_alias in obj_data["ifaces"]: + iface_aliases.add(iface_alias) + for iface_alias in iface_aliases: + print("Creating {}".format(iface_alias)) + iface_data = topo[iface_alias] + iface_type = iface_data.get("type", "epair") + ifaces = iface_factory.create_iface(iface_alias, iface_type) + smap = SingleInterfaceMap(ifaces, []) + iface_map[iface_alias] = smap + return iface_map + + def setup_topology(self, topo: Dict, topology_id: str): + """Creates jails & interfaces for the provided topology""" + vnet_map = {} + vnet_factory = VnetFactory(topology_id) + iface_map = self._get_topo_ifmap(topo) + for obj_name, obj_data in topo.items(): + if obj_name.startswith("vnet"): + vnet_ifaces = [] + for iface_alias in obj_data["ifaces"]: + # epair creates 2 interfaces, grab first _available_ + # and map it to the VNET being created + idx = len(iface_map[iface_alias].vnet_aliases) + iface_map[iface_alias].vnet_aliases.append(obj_name) + vnet_ifaces.append(iface_map[iface_alias].ifaces[idx]) + opts = [] + if "opts" in obj_data: + opts = obj_data["opts"] + vnet = vnet_factory.create_vnet(obj_name, vnet_ifaces, opts) + vnet_map[obj_name] = vnet + # Allow reference to VNETs as attributes + setattr(self, obj_name, vnet) + # Debug output + print("============= TEST TOPOLOGY =============") + for vnet_alias, vnet in vnet_map.items(): + print("# vnet {} -> {}".format(vnet.alias, vnet.name), end="") + handler = self._get_vnet_handler(vnet.alias) + if handler: + print(" handler: {}".format(handler.__name__), end="") + print() + for iface_alias, iface_data in iface_map.items(): + vnets = iface_data.vnet_aliases + ifaces: List[VnetInterface] = iface_data.ifaces + if len(vnets) == 1 and len(ifaces) == 2: + print( + "# iface {}: {}::{} -> main::{}".format( + iface_alias, vnets[0], ifaces[0].name, ifaces[1].name + ) + ) + elif len(vnets) == 2 and len(ifaces) == 2: + print( + "# iface {}: {}::{} -> {}::{}".format( + iface_alias, vnets[0], ifaces[0].name, vnets[1], ifaces[1].name + ) + ) + else: + print( + "# iface {}: ifaces: {} vnets: {}".format( + iface_alias, vnets, [i.name for i in ifaces] + ) + ) + print() + return ObjectsMap(iface_map, vnet_map, topo) + + def setup_method(self, _method): + """Sets up all the required topology and handlers for the given test""" + super().setup_method(_method) + self._require_default_modules() + + # TestIP6Output.test_output6_pktinfo[ipandif] + topology_id = get_topology_id(self.test_id) + topology = self.TOPOLOGY + # First, setup kernel objects - interfaces & vnets + obj_map = self.setup_topology(topology, topology_id) + main_vnet = None # one without subprocess handler + for vnet_alias, vnet in obj_map.vnet_map.items(): + if self._get_vnet_handler(vnet_alias): + # Need subprocess to run + parent_pipe, child_pipe = Pipe() + p = Process( + target=self._setup_vnet, + args=( + vnet, + obj_map, + child_pipe, + ), + ) + vnet.set_pipe(parent_pipe) + vnet.set_subprocess(p) + p.start() + else: + if main_vnet is not None: + raise Exception("there can be only 1 VNET w/o handler") + main_vnet = vnet + # Main vnet needs to be the last, so all the other subprocesses + # are started & their pipe handles collected + self.vnet = main_vnet + self._setup_vnet(main_vnet, obj_map, None) + # Save state for the main handler + self.iface_map = obj_map.iface_map + self.vnet_map = obj_map.vnet_map + self.drop_privileges() + + def cleanup(self, test_id: str): + # pytest test id: file::class::test_name + topology_id = get_topology_id(self.test_id) + + print("============= vnet cleanup =============") + print("# topology_id: '{}'".format(topology_id)) + VnetFactory(topology_id).cleanup() + IfaceFactory().cleanup() + + def wait_object(self, pipe, timeout=5): + if pipe.poll(timeout): + return pipe.recv() + raise TimeoutError + + def wait_objects_any(self, pipe_list, timeout=5): + objects = connection.wait(pipe_list, timeout) + if objects: + return objects[0].recv() + raise TimeoutError + + def send_object(self, pipe, obj): + pipe.send(obj) + + def wait(self): + while True: + time.sleep(1) + + @property + def curvnet(self): + pass + + +class SingleVnetTestTemplate(VnetTestTemplate): + IPV6_PREFIXES: List[str] = [] + IPV4_PREFIXES: List[str] = [] + IFTYPE = "epair" + + def _setup_default_topology(self): + topology = copy.deepcopy( + { + "vnet1": {"ifaces": ["if1"]}, + "if1": {"type": self.IFTYPE, "prefixes4": [], "prefixes6": []}, + } + ) + for prefix in self.IPV6_PREFIXES: + topology["if1"]["prefixes6"].append((prefix,)) + for prefix in self.IPV4_PREFIXES: + topology["if1"]["prefixes4"].append((prefix,)) + return topology + + def setup_method(self, method): + if not getattr(self, "TOPOLOGY", None): + self.TOPOLOGY = self._setup_default_topology() + else: + names = self.TOPOLOGY.keys() + assert len([n for n in names if n.startswith("vnet")]) == 1 + super().setup_method(method) |