diff options
| author | Alexander V. Chernikov <melifaro@FreeBSD.org> | 2022-06-25 19:01:45 +0000 |
|---|---|---|
| committer | Alexander V. Chernikov <melifaro@FreeBSD.org> | 2022-06-25 19:25:15 +0000 |
| commit | 8eb2bee6c0f4957c6c1cea826e59cda4d18a2a64 (patch) | |
| tree | 8a6481d536e076810de128b0ba49cece8b671554 /tests/atf_python | |
| parent | c38da70c28a886cc31a2f009baa79deb7fceec88 (diff) | |
Diffstat (limited to 'tests/atf_python')
| -rw-r--r-- | tests/atf_python/Makefile | 12 | ||||
| -rw-r--r-- | tests/atf_python/__init__.py | 4 | ||||
| -rw-r--r-- | tests/atf_python/atf_pytest.py | 218 | ||||
| -rw-r--r-- | tests/atf_python/sys/Makefile | 11 | ||||
| -rw-r--r-- | tests/atf_python/sys/__init__.py | 0 | ||||
| -rw-r--r-- | tests/atf_python/sys/net/Makefile | 10 | ||||
| -rw-r--r-- | tests/atf_python/sys/net/__init__.py | 0 | ||||
| -rwxr-xr-x | tests/atf_python/sys/net/rtsock.py | 604 | ||||
| -rw-r--r-- | tests/atf_python/sys/net/tools.py | 33 | ||||
| -rw-r--r-- | tests/atf_python/sys/net/vnet.py | 203 |
10 files changed, 1095 insertions, 0 deletions
diff --git a/tests/atf_python/Makefile b/tests/atf_python/Makefile new file mode 100644 index 000000000000..26d419743257 --- /dev/null +++ b/tests/atf_python/Makefile @@ -0,0 +1,12 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +FILES= __init__.py atf_pytest.py +SUBDIR= sys + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python + + +.include <bsd.prog.mk> diff --git a/tests/atf_python/__init__.py b/tests/atf_python/__init__.py new file mode 100644 index 000000000000..6d5ec22ef054 --- /dev/null +++ b/tests/atf_python/__init__.py @@ -0,0 +1,4 @@ +import pytest + +pytest.register_assert_rewrite("atf_python.sys.net.rtsock") +pytest.register_assert_rewrite("atf_python.sys.net.vnet") diff --git a/tests/atf_python/atf_pytest.py b/tests/atf_python/atf_pytest.py new file mode 100644 index 000000000000..89c0e3a515b9 --- /dev/null +++ b/tests/atf_python/atf_pytest.py @@ -0,0 +1,218 @@ +import types +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Tuple + +import pytest + + +class ATFCleanupItem(pytest.Item): + def runtest(self): + """Runs cleanup procedure for the test instead of the test""" + instance = self.parent.cls() + instance.cleanup(self.nodeid) + + def setup_method_noop(self, method): + """Overrides runtest setup method""" + pass + + def teardown_method_noop(self, method): + """Overrides runtest teardown method""" + pass + + +class ATFTestObj(object): + def __init__(self, obj, has_cleanup): + # Use nodeid without name to properly name class-derived tests + self.ident = obj.nodeid.split("::", 1)[1] + self.description = self._get_test_description(obj) + self.has_cleanup = has_cleanup + self.obj = obj + + def _get_test_description(self, obj): + """Returns first non-empty line from func docstring or func name""" + docstr = obj.function.__doc__ + if docstr: + for line in docstr.split("\n"): + if line: + return line + return obj.name + + def _convert_marks(self, obj) -> Dict[str, Any]: + wj_func = lambda x: " ".join(x) # noqa: E731 + _map: Dict[str, Dict] = { + "require_user": {"name": "require.user"}, + "require_arch": {"name": "require.arch", "fmt": wj_func}, + "require_diskspace": {"name": "require.diskspace"}, + "require_files": {"name": "require.files", "fmt": wj_func}, + "require_machine": {"name": "require.machine", "fmt": wj_func}, + "require_memory": {"name": "require.memory"}, + "require_progs": {"name": "require.progs", "fmt": wj_func}, + "timeout": {}, + } + ret = {} + for mark in obj.iter_markers(): + if mark.name in _map: + name = _map[mark.name].get("name", mark.name) + if "fmt" in _map[mark.name]: + val = _map[mark.name]["fmt"](mark.args[0]) + else: + val = mark.args[0] + ret[name] = val + return ret + + def as_lines(self) -> List[str]: + """Output test definition in ATF-specific format""" + ret = [] + ret.append("ident: {}".format(self.ident)) + ret.append("descr: {}".format(self._get_test_description(self.obj))) + if self.has_cleanup: + ret.append("has.cleanup: true") + for key, value in self._convert_marks(self.obj).items(): + ret.append("{}: {}".format(key, value)) + return ret + + +class ATFHandler(object): + class ReportState(NamedTuple): + state: str + reason: str + + def __init__(self): + self._tests_state_map: Dict[str, ReportStatus] = {} + + def override_runtest(self, obj): + # Override basic runtest command + obj.runtest = types.MethodType(ATFCleanupItem.runtest, obj) + # Override class setup/teardown + obj.parent.cls.setup_method = ATFCleanupItem.setup_method_noop + obj.parent.cls.teardown_method = ATFCleanupItem.teardown_method_noop + + def get_object_cleanup_class(self, obj): + if hasattr(obj, "parent") and obj.parent is not None: + if hasattr(obj.parent, "cls") and obj.parent.cls is not None: + if hasattr(obj.parent.cls, "cleanup"): + return obj.parent.cls + return None + + def has_object_cleanup(self, obj): + return self.get_object_cleanup_class(obj) is not None + + def list_tests(self, tests: List[str]): + print('Content-Type: application/X-atf-tp; version="1"') + print() + for test_obj in tests: + has_cleanup = self.has_object_cleanup(test_obj) + atf_test = ATFTestObj(test_obj, has_cleanup) + for line in atf_test.as_lines(): + print(line) + print() + + def set_report_state(self, test_name: str, state: str, reason: str): + self._tests_state_map[test_name] = self.ReportState(state, reason) + + def _extract_report_reason(self, report): + data = report.longrepr + if data is None: + return None + if isinstance(data, Tuple): + # ('/path/to/test.py', 23, 'Skipped: unable to test') + reason = data[2] + for prefix in "Skipped: ": + if reason.startswith(prefix): + reason = reason[len(prefix):] + return reason + else: + # string/ traceback / exception report. Capture the last line + return str(data).split("\n")[-1] + return None + + def add_report(self, report): + # MAP pytest report state to the atf-desired state + # + # ATF test states: + # (1) expected_death, (2) expected_exit, (3) expected_failure + # (4) expected_signal, (5) expected_timeout, (6) passed + # (7) skipped, (8) failed + # + # Note that ATF don't have the concept of "soft xfail" - xpass + # is a failure. It also calls teardown routine in a separate + # process, thus teardown states (pytest-only) are handled as + # body continuation. + + # (stage, state, wasxfail) + + # Just a passing test: WANT: passed + # GOT: (setup, passed, F), (call, passed, F), (teardown, passed, F) + # + # Failing body test: WHAT: failed + # GOT: (setup, passed, F), (call, failed, F), (teardown, passed, F) + # + # pytest.skip test decorator: WANT: skipped + # GOT: (setup,skipped, False), (teardown, passed, False) + # + # pytest.skip call inside test function: WANT: skipped + # GOT: (setup, passed, F), (call, skipped, F), (teardown,passed, F) + # + # mark.xfail decorator+pytest.xfail: WANT: expected_failure + # GOT: (setup, passed, F), (call, skipped, T), (teardown, passed, F) + # + # mark.xfail decorator+pass: WANT: failed + # GOT: (setup, passed, F), (call, passed, T), (teardown, passed, F) + + test_name = report.location[2] + stage = report.when + state = report.outcome + reason = self._extract_report_reason(report) + + # We don't care about strict xfail - it gets translated to False + + if stage == "setup": + if state in ("skipped", "failed"): + # failed init -> failed test, skipped setup -> xskip + # for the whole test + self.set_report_state(test_name, state, reason) + elif stage == "call": + # "call" stage shouldn't matter if setup failed + if test_name in self._tests_state_map: + if self._tests_state_map[test_name].state == "failed": + return + if state == "failed": + # Record failure & override "skipped" state + self.set_report_state(test_name, state, reason) + elif state == "skipped": + if hasattr(reason, "wasxfail"): + # xfail() called in the test body + state = "expected_failure" + else: + # skip inside the body + pass + self.set_report_state(test_name, state, reason) + elif state == "passed": + if hasattr(reason, "wasxfail"): + # the test was expected to fail but didn't + # mark as hard failure + state = "failed" + self.set_report_state(test_name, state, reason) + elif stage == "teardown": + if state == "failed": + # teardown should be empty, as the cleanup + # procedures should be implemented as a separate + # function/method, so mark teardown failure as + # global failure + self.set_report_state(test_name, state, reason) + + def write_report(self, path): + if self._tests_state_map: + # If we're executing in ATF mode, there has to be just one test + # Anyway, deterministically pick the first one + first_test_name = next(iter(self._tests_state_map)) + test = self._tests_state_map[first_test_name] + if test.state == "passed": + line = test.state + else: + line = "{}: {}".format(test.state, test.reason) + with open(path, mode="w") as f: + print(line, file=f) diff --git a/tests/atf_python/sys/Makefile b/tests/atf_python/sys/Makefile new file mode 100644 index 000000000000..ff4cf17b85d2 --- /dev/null +++ b/tests/atf_python/sys/Makefile @@ -0,0 +1,11 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +FILES= __init__.py +SUBDIR= net + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python/sys + +.include <bsd.prog.mk> diff --git a/tests/atf_python/sys/__init__.py b/tests/atf_python/sys/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 --- /dev/null +++ b/tests/atf_python/sys/__init__.py diff --git a/tests/atf_python/sys/net/Makefile b/tests/atf_python/sys/net/Makefile new file mode 100644 index 000000000000..05b1d8afe863 --- /dev/null +++ b/tests/atf_python/sys/net/Makefile @@ -0,0 +1,10 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +FILES= __init__.py rtsock.py tools.py vnet.py + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python/sys/net + +.include <bsd.prog.mk> diff --git a/tests/atf_python/sys/net/__init__.py b/tests/atf_python/sys/net/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 --- /dev/null +++ b/tests/atf_python/sys/net/__init__.py diff --git a/tests/atf_python/sys/net/rtsock.py b/tests/atf_python/sys/net/rtsock.py new file mode 100755 index 000000000000..788e863f8b28 --- /dev/null +++ b/tests/atf_python/sys/net/rtsock.py @@ -0,0 +1,604 @@ +#!/usr/local/bin/python3 +import os +import socket +import struct +import sys +from ctypes import c_byte +from ctypes import c_char +from ctypes import c_int +from ctypes import c_long +from ctypes import c_uint32 +from ctypes import c_ulong +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + + +def roundup2(val: int, num: int) -> int: + if val % num: + return (val | (num - 1)) + 1 + else: + return val + + +class RtSockException(OSError): + pass + + +class RtConst: + RTM_VERSION = 5 + ALIGN = sizeof(c_long) + + AF_INET = socket.AF_INET + AF_INET6 = socket.AF_INET6 + AF_LINK = socket.AF_LINK + + RTA_DST = 0x1 + RTA_GATEWAY = 0x2 + RTA_NETMASK = 0x4 + RTA_GENMASK = 0x8 + RTA_IFP = 0x10 + RTA_IFA = 0x20 + RTA_AUTHOR = 0x40 + RTA_BRD = 0x80 + + RTM_ADD = 1 + RTM_DELETE = 2 + RTM_CHANGE = 3 + RTM_GET = 4 + + RTF_UP = 0x1 + RTF_GATEWAY = 0x2 + RTF_HOST = 0x4 + RTF_REJECT = 0x8 + RTF_DYNAMIC = 0x10 + RTF_MODIFIED = 0x20 + RTF_DONE = 0x40 + RTF_XRESOLVE = 0x200 + RTF_LLINFO = 0x400 + RTF_LLDATA = 0x400 + RTF_STATIC = 0x800 + RTF_BLACKHOLE = 0x1000 + RTF_PROTO2 = 0x4000 + RTF_PROTO1 = 0x8000 + RTF_PROTO3 = 0x40000 + RTF_FIXEDMTU = 0x80000 + RTF_PINNED = 0x100000 + RTF_LOCAL = 0x200000 + RTF_BROADCAST = 0x400000 + RTF_MULTICAST = 0x800000 + RTF_STICKY = 0x10000000 + RTF_RNH_LOCKED = 0x40000000 + RTF_GWFLAG_COMPAT = 0x80000000 + + RTV_MTU = 0x1 + RTV_HOPCOUNT = 0x2 + RTV_EXPIRE = 0x4 + RTV_RPIPE = 0x8 + RTV_SPIPE = 0x10 + RTV_SSTHRESH = 0x20 + RTV_RTT = 0x40 + RTV_RTTVAR = 0x80 + RTV_WEIGHT = 0x100 + + @staticmethod + def get_props(prefix: str) -> List[str]: + return [n for n in dir(RtConst) if n.startswith(prefix)] + + @staticmethod + def get_name(prefix: str, value: int) -> str: + props = RtConst.get_props(prefix) + for prop in props: + if getattr(RtConst, prop) == value: + return prop + return "U:{}:{}".format(prefix, value) + + @staticmethod + def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]: + props = RtConst.get_props(prefix) + propmap = {getattr(RtConst, prop): prop for prop in props} + v = 1 + ret = {} + while value: + if v & value: + if v in propmap: + ret[v] = propmap[v] + else: + ret[v] = hex(v) + value -= v + v *= 2 + return ret + + @staticmethod + def get_bitmask_str(prefix: str, value: int) -> str: + bmap = RtConst.get_bitmask_map(prefix, value) + return ",".join([v for k, v in bmap.items()]) + + +class RtMetrics(Structure): + _fields_ = [ + ("rmx_locks", c_ulong), + ("rmx_mtu", c_ulong), + ("rmx_hopcount", c_ulong), + ("rmx_expire", c_ulong), + ("rmx_recvpipe", c_ulong), + ("rmx_sendpipe", c_ulong), + ("rmx_ssthresh", c_ulong), + ("rmx_rtt", c_ulong), + ("rmx_rttvar", c_ulong), + ("rmx_pksent", c_ulong), + ("rmx_weight", c_ulong), + ("rmx_nhidx", c_ulong), + ("rmx_filler", c_ulong * 2), + ] + + +class RtMsgHdr(Structure): + _fields_ = [ + ("rtm_msglen", c_ushort), + ("rtm_version", c_byte), + ("rtm_type", c_byte), + ("rtm_index", c_ushort), + ("_rtm_spare1", c_ushort), + ("rtm_flags", c_int), + ("rtm_addrs", c_int), + ("rtm_pid", c_int), + ("rtm_seq", c_int), + ("rtm_errno", c_int), + ("rtm_fmask", c_int), + ("rtm_inits", c_ulong), + ("rtm_rmx", RtMetrics), + ] + + +class SockaddrIn(Structure): + _fields_ = [ + ("sin_len", c_byte), + ("sin_family", c_byte), + ("sin_port", c_ushort), + ("sin_addr", c_uint32), + ("sin_zero", c_char * 8), + ] + + +class SockaddrIn6(Structure): + _fields_ = [ + ("sin6_len", c_byte), + ("sin6_family", c_byte), + ("sin6_port", c_ushort), + ("sin6_flowinfo", c_uint32), + ("sin6_addr", c_byte * 16), + ("sin6_scope_id", c_uint32), + ] + + +class SockaddrDl(Structure): + _fields_ = [ + ("sdl_len", c_byte), + ("sdl_family", c_byte), + ("sdl_index", c_ushort), + ("sdl_type", c_byte), + ("sdl_nlen", c_byte), + ("sdl_alen", c_byte), + ("sdl_slen", c_byte), + ("sdl_data", c_byte * 8), + ] + + +class SaHelper(object): + @staticmethod + def is_ipv6(ip: str) -> bool: + return ":" in ip + + @staticmethod + def ip_sa(ip: str, scopeid: int = 0) -> bytes: + if SaHelper.is_ipv6(ip): + return SaHelper.ip6_sa(ip, scopeid) + else: + return SaHelper.ip4_sa(ip) + + @staticmethod + def ip4_sa(ip: str) -> bytes: + addr_int = int.from_bytes(socket.inet_pton(2, ip), sys.byteorder) + sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int) + return bytes(sin) + + @staticmethod + def ip6_sa(ip6: str, scopeid: int) -> bytes: + addr_bytes = (c_byte * 16)() + for i, b in enumerate(socket.inet_pton(socket.AF_INET6, ip6)): + addr_bytes[i] = b + sin6 = SockaddrIn6( + sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid + ) + return bytes(sin6) + + @staticmethod + def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes: + sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype) + return bytes(sa) + + @staticmethod + def pxlen4_sa(pxlen: int) -> bytes: + return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen)) + + @staticmethod + def pxlen_to_ip4(pxlen: int) -> str: + if pxlen == 32: + return "255.255.255.255" + else: + addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1) + addr_bytes = struct.pack("!I", addr) + return socket.inet_ntop(socket.AF_INET, addr_bytes) + + @staticmethod + def pxlen6_sa(pxlen: int) -> bytes: + return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen)) + + @staticmethod + def pxlen_to_ip6(pxlen: int) -> str: + ip6_b = [0] * 16 + start = 0 + while pxlen > 8: + ip6_b[start] = 0xFF + pxlen -= 8 + start += 1 + ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1) + return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b)) + + @staticmethod + def print_sa_inet(sa: bytes): + if len(sa) < 8: + raise RtSockException("IPv4 sa size too small: {}".format(len(sa))) + addr = socket.inet_ntop(socket.AF_INET, sa[4:8]) + return "{}".format(addr) + + @staticmethod + def print_sa_inet6(sa: bytes): + if len(sa) < sizeof(SockaddrIn6): + raise RtSockException("IPv6 sa size too small: {}".format(len(sa))) + addr = socket.inet_ntop(socket.AF_INET6, sa[8:24]) + scopeid = struct.unpack(">I", sa[24:28])[0] + return "{} scopeid {}".format(addr, scopeid) + + @staticmethod + def print_sa_link(sa: bytes, hd: Optional[bool] = True): + if len(sa) < sizeof(SockaddrDl): + raise RtSockException("LINK sa size too small: {}".format(len(sa))) + sdl = SockaddrDl.from_buffer_copy(sa) + if sdl.sdl_index: + ifindex = "link#{} ".format(sdl.sdl_index) + else: + ifindex = "" + if sdl.sdl_nlen: + iface_offset = 8 + if sdl.sdl_nlen + iface_offset > len(sa): + raise RtSockException( + "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa)) + ) + ifname = "ifname:{} ".format( + bytes.decode(sa[iface_offset : iface_offset + sdl.sdl_nlen]) + ) + else: + ifname = "" + return "{}{}".format(ifindex, ifname) + + @staticmethod + def print_sa_unknown(sa: bytes): + return "unknown_type:{}".format(sa[1]) + + @classmethod + def print_sa(cls, sa: bytes, hd: Optional[bool] = False): + if sa[0] != len(sa): + raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa))) + + if len(sa) < 2: + raise Exception( + "sa type {} too short: {}".format( + RtConst.get_name("AF_", sa[1]), len(sa) + ) + ) + + if sa[1] == socket.AF_INET: + text = cls.print_sa_inet(sa) + elif sa[1] == socket.AF_INET6: + text = cls.print_sa_inet6(sa) + elif sa[1] == socket.AF_LINK: + text = cls.print_sa_link(sa) + else: + text = cls.print_sa_unknown(sa) + if hd: + dump = " [{!r}]".format(sa) + else: + dump = "" + return "{}{}".format(text, dump) + + +class BaseRtsockMessage(object): + def __init__(self, rtm_type): + self.rtm_type = rtm_type + self.sa = SaHelper() + + @staticmethod + def print_rtm_type(rtm_type): + return RtConst.get_name("RTM_", rtm_type) + + @property + def rtm_type_str(self): + return self.print_rtm_type(self.rtm_type) + + +class RtsockRtMessage(BaseRtsockMessage): + messages = [ + RtConst.RTM_ADD, + RtConst.RTM_DELETE, + RtConst.RTM_CHANGE, + RtConst.RTM_GET, + ] + + def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None): + super().__init__(rtm_type) + self.rtm_flags = 0 + self.rtm_seq = rtm_seq + self._attrs = {} + self.rtm_errno = 0 + self.rtm_pid = 0 + self.rtm_inits = 0 + self.rtm_rmx = RtMetrics() + self._orig_data = None + if dst_sa: + self.add_sa_attr(RtConst.RTA_DST, dst_sa) + if mask_sa: + self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa) + + def add_sa_attr(self, attr_type, attr_bytes: bytes): + self._attrs[attr_type] = attr_bytes + + def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0): + if ":" in ip_addr: + self.add_ip6_attr(attr_type, ip_addr, scopeid) + else: + self.add_ip4_attr(attr_type, ip_addr) + + def add_ip4_attr(self, attr_type, ip: str): + self.add_sa_attr(attr_type, self.sa.ip_sa(ip)) + + def add_ip6_attr(self, attr_type, ip6: str, scopeid: int): + self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid)) + + def add_link_attr(self, attr_type, ifindex: Optional[int] = 0): + self.add_sa_attr(attr_type, self.sa.link_sa(ifindex)) + + def get_sa(self, attr_type) -> bytes: + return self._attrs.get(attr_type) + + def print_message(self): + # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC> + if self._orig_data: + rtm_len = len(self._orig_data) + else: + rtm_len = len(bytes(self)) + print( + "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format( + self.rtm_type_str, + rtm_len, + self.rtm_pid, + self.rtm_seq, + self.rtm_errno, + RtConst.get_bitmask_str("RTF_", self.rtm_flags), + ) + ) + rtm_addrs = sum(list(self._attrs.keys())) + print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs))) + for attr in sorted(self._attrs.keys()): + sa_data = SaHelper.print_sa(self._attrs[attr]) + print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data)) + + def print_in_message(self): + print("vvvvvvvv IN vvvvvvvv") + self.print_message() + print() + + def verify_sa_inet(self, sa_data): + if len(sa_data) < 8: + raise Exception("IPv4 sa size too small: {}".format(sa_data)) + if sa_data[0] > len(sa_data): + raise Exception( + "IPv4 sin_len too big: {} vs sa size {}: {}".format( + sa_data[0], len(sa_data), sa_data + ) + ) + sin = SockaddrIn.from_buffer_copy(sa_data) + assert sin.sin_port == 0 + assert sin.sin_zero == [0] * 8 + + def compare_sa(self, sa_type, sa_data): + if len(sa_data) < 4: + sa_type_name = RtConst.get_name("RTA_", sa_type) + raise Exception( + "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data)) + ) + our_sa = self._attrs[sa_type] + assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa) + assert len(sa_data) == len(our_sa) + assert sa_data == our_sa + + def verify(self, rtm_type: int, rtm_sa): + assert self.rtm_type_str == self.print_rtm_type(rtm_type) + assert self.rtm_errno == 0 + hdr = RtMsgHdr.from_buffer_copy(self._orig_data) + assert hdr._rtm_spare1 == 0 + for sa_type, sa_data in rtm_sa.items(): + if sa_type not in self._attrs: + sa_type_name = RtConst.get_name("RTA_", sa_type) + raise Exception("SA type {} not present".format(sa_type_name)) + self.compare_sa(sa_type, sa_data) + + @classmethod + def from_bytes(cls, data: bytes): + if len(data) < sizeof(RtMsgHdr): + raise Exception( + "messages size {} is less than expected {}".format( + len(data), sizeof(RtMsgHdr) + ) + ) + hdr = RtMsgHdr.from_buffer_copy(data) + + self = cls(hdr.rtm_type) + self.rtm_flags = hdr.rtm_flags + self.rtm_seq = hdr.rtm_seq + self.rtm_errno = hdr.rtm_errno + self.rtm_pid = hdr.rtm_pid + self.rtm_inits = hdr.rtm_inits + self.rtm_rmx = hdr.rtm_rmx + self._orig_data = data + + off = sizeof(RtMsgHdr) + v = 1 + addrs_mask = hdr.rtm_addrs + while addrs_mask: + if addrs_mask & v: + addrs_mask -= v + + if off + data[off] > len(data): + raise Exception( + "SA sizeof for {} > total message length: {}+{} > {}".format( + RtConst.get_name("RTA_", v), off, data[off], len(data) + ) + ) + self._attrs[v] = data[off : off + data[off]] + off += roundup2(data[off], RtConst.ALIGN) + v *= 2 + return self + + def __bytes__(self): + sz = sizeof(RtMsgHdr) + addrs_mask = 0 + for k, v in self._attrs.items(): + sz += roundup2(len(v), RtConst.ALIGN) + addrs_mask += k + hdr = RtMsgHdr( + rtm_msglen=sz, + rtm_version=RtConst.RTM_VERSION, + rtm_type=self.rtm_type, + rtm_flags=self.rtm_flags, + rtm_seq=self.rtm_seq, + rtm_addrs=addrs_mask, + rtm_inits=self.rtm_inits, + rtm_rmx=self.rtm_rmx, + ) + buf = bytearray(sz) + buf[0 : sizeof(RtMsgHdr)] = hdr + off = sizeof(RtMsgHdr) + for attr in sorted(self._attrs.keys()): + v = self._attrs[attr] + sa_len = len(v) + buf[off : off + sa_len] = v + off += roundup2(len(v), RtConst.ALIGN) + return bytes(buf) + + +class Rtsock: + def __init__(self): + self.socket = self._setup_rtsock() + self.rtm_seq = 1 + self.msgmap = self.build_msgmap() + + def build_msgmap(self): + classes = [RtsockRtMessage] + xmap = {} + for cls in classes: + for message in cls.messages: + xmap[message] = cls + return xmap + + def get_seq(self): + ret = self.rtm_seq + self.rtm_seq += 1 + return ret + + def get_weight(self, weight) -> int: + if weight: + return weight + else: + return 1 # RT_DEFAULT_WEIGHT + + def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]): + px = prefix.split("/") + addr_sa = SaHelper.ip_sa(px[0]) + if len(px) > 1: + pxlen = int(px[1]) + if SaHelper.is_ipv6(px[0]): + mask_sa = SaHelper.pxlen6_sa(pxlen) + else: + mask_sa = SaHelper.pxlen4_sa(pxlen) + else: + mask_sa = None + msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa) + if isinstance(gw, bytes): + msg.add_sa_attr(RtConst.RTA_GATEWAY, gw) + else: + # String + msg.add_ip_attr(RtConst.RTA_GATEWAY, gw) + return msg + + def new_rtm_add(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw) + + def new_rtm_del(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw) + + def new_rtm_change(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw) + + def _setup_rtsock(self) -> socket.socket: + s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC) + s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1) + return s + + def print_hd(self, data: bytes): + width = 16 + print("==========================================") + for chunk in [data[i : i + width] for i in range(0, len(data), width)]: + for b in chunk: + print("0x{:02X} ".format(b), end="") + print() + print() + + def write_message(self, msg): + print("vvvvvvvv OUT vvvvvvvv") + msg.print_message() + print() + msg_bytes = bytes(msg) + ret = os.write(self.socket.fileno(), msg_bytes) + if ret != -1: + assert ret == len(msg_bytes) + + def parse_message(self, data: bytes): + if len(data) < 4: + raise OSError("Short read from rtsock: {} bytes".format(len(data))) + rtm_type = data[4] + if rtm_type not in self.msgmap: + return None + + def write_data(self, data: bytes): + self.socket.send(data) + + def read_data(self, seq: Optional[int] = None) -> bytes: + while True: + data = self.socket.recv(4096) + if seq is None: + break + if len(data) > sizeof(RtMsgHdr): + hdr = RtMsgHdr.from_buffer_copy(data) + if hdr.rtm_seq == seq: + break + return data + + def read_message(self) -> bytes: + data = self.read_data() + return self.parse_message(data) diff --git a/tests/atf_python/sys/net/tools.py b/tests/atf_python/sys/net/tools.py new file mode 100644 index 000000000000..9f44872c2c37 --- /dev/null +++ b/tests/atf_python/sys/net/tools.py @@ -0,0 +1,33 @@ +#!/usr/local/bin/python3 +import json +import os +import socket +import time +from ctypes import cdll +from ctypes import get_errno +from ctypes.util import find_library +from typing import List +from typing import Optional + + +class ToolsHelper(object): + NETSTAT_PATH = "/usr/bin/netstat" + + @classmethod + def get_output(cls, cmd: str, verbose=False) -> str: + if verbose: + print("run: '{}'".format(cmd)) + return os.popen(cmd).read() + + @classmethod + def get_routes(cls, family: str, fibnum: int = 0): + family_key = {"inet": "-4", "inet6": "-6"}.get(family) + out = cls.get_output( + "{} {} -rn -F {} --libxo json".format(cls.NETSTAT_PATH, family_key, fibnum) + ) + js = json.loads(out) + js = js["statistics"]["route-information"]["route-table"]["rt-family"] + if js: + return js[0]["rt-entry"] + else: + return [] diff --git a/tests/atf_python/sys/net/vnet.py b/tests/atf_python/sys/net/vnet.py new file mode 100644 index 000000000000..0957364f627c --- /dev/null +++ b/tests/atf_python/sys/net/vnet.py @@ -0,0 +1,203 @@ +#!/usr/local/bin/python3 +import os +import socket +import time +from ctypes import cdll +from ctypes import get_errno +from ctypes.util import find_library +from typing import List +from typing import Optional + + +def run_cmd(cmd: str) -> str: + print("run: '{}'".format(cmd)) + return os.popen(cmd).read() + + +class VnetInterface(object): + INTERFACES_FNAME = "created_interfaces.lst" + + # defines from net/if_types.h + IFT_LOOP = 0x18 + IFT_ETHER = 0x06 + + def __init__(self, iface_name: str): + self.name = iface_name + self.vnet_name = "" + self.jailed = False + if iface_name.startswith("lo"): + self.iftype = self.IFT_LOOP + else: + self.iftype = self.IFT_ETHER + + @property + def ifindex(self): + return socket.if_nametoindex(self.name) + + def set_vnet(self, vnet_name: str): + self.vnet_name = vnet_name + + def set_jailed(self, jailed: bool): + self.jailed = jailed + + def run_cmd(self, cmd): + if self.vnet_name and not self.jailed: + cmd = "jexec {} {}".format(self.vnet_name, cmd) + run_cmd(cmd) + + @staticmethod + def file_append_line(line): + with open(VnetInterface.INTERFACES_FNAME, "a") as f: + f.write(line + "\n") + + @classmethod + def create_iface(cls, iface_name: str): + name = run_cmd("/sbin/ifconfig {} create".format(iface_name)).rstrip() + if not name: + raise Exception("Unable to create iface {}".format(iface_name)) + cls.file_append_line(name) + if name.startswith("epair"): + cls.file_append_line(name[:-1] + "b") + return cls(name) + + @staticmethod + def cleanup_ifaces(): + try: + with open(VnetInterface.INTERFACES_FNAME, "r") as f: + for line in f: + run_cmd("/sbin/ifconfig {} destroy".format(line.strip())) + os.unlink(VnetInterface.INTERFACES_FNAME) + except Exception: + pass + + def setup_addr(self, addr: str): + if ":" in addr: + family = "inet6" + else: + family = "inet" + cmd = "/sbin/ifconfig {} {} {}".format(self.name, family, addr) + self.run_cmd(cmd) + + def delete_addr(self, addr: str): + if ":" in addr: + cmd = "/sbin/ifconfig {} inet6 {} delete".format(self.name, addr) + else: + cmd = "/sbin/ifconfig {} -alias {}".format(self.name, addr) + self.run_cmd(cmd) + + def turn_up(self): + cmd = "/sbin/ifconfig {} up".format(self.name) + self.run_cmd(cmd) + + def enable_ipv6(self): + cmd = "/usr/sbin/ndp -i {} -disabled".format(self.name) + self.run_cmd(cmd) + + +class VnetInstance(object): + JAILS_FNAME = "created_jails.lst" + + def __init__(self, vnet_name: str, jid: int, ifaces: List[VnetInterface]): + self.name = vnet_name + self.jid = jid + self.ifaces = ifaces + for iface in ifaces: + iface.set_vnet(vnet_name) + iface.set_jailed(True) + + def run_vnet_cmd(self, cmd): + if self.vnet_name: + cmd = "jexec {} {}".format(self.vnet_name, cmd) + return run_cmd(cmd) + + @staticmethod + def wait_interface(vnet_name: str, iface_name: str): + cmd = "jexec {} /sbin/ifconfig -l".format(vnet_name) + for i in range(50): + ifaces = run_cmd(cmd).strip().split(" ") + if iface_name in ifaces: + return True + time.sleep(0.1) + return False + + @staticmethod + def file_append_line(line): + with open(VnetInstance.JAILS_FNAME, "a") as f: + f.write(line + "\n") + + @staticmethod + def cleanup_vnets(): + try: + with open(VnetInstance.JAILS_FNAME) as f: + for line in f: + run_cmd("/usr/sbin/jail -r {}".format(line.strip())) + os.unlink(VnetInstance.JAILS_FNAME) + except Exception: + pass + + @classmethod + def create_with_interfaces(cls, vnet_name: str, ifaces: List[VnetInterface]): + iface_cmds = " ".join(["vnet.interface={}".format(i.name) for i in ifaces]) + cmd = "/usr/sbin/jail -i -c name={} persist vnet {}".format( + vnet_name, iface_cmds + ) + jid_str = run_cmd(cmd) + jid = int(jid_str) + if jid <= 0: + raise Exception("Jail creation failed, output: {}".format(jid)) + cls.file_append_line(vnet_name) + + for iface in ifaces: + if cls.wait_interface(vnet_name, iface.name): + continue + raise Exception( + "Interface {} has not appeared in vnet {}".format(iface.name, vnet_name) + ) + return cls(vnet_name, jid, ifaces) + + @staticmethod + def attach_jid(jid: int): + _path: Optional[str] = find_library("c") + if _path is None: + raise Exception("libc not found") + path: str = _path + libc = cdll.LoadLibrary(path) + if libc.jail_attach(jid) != 0: + raise Exception("jail_attach() failed: errno {}".format(get_errno())) + + def attach(self): + self.attach_jid(self.jid) + + +class SingleVnetTestTemplate(object): + num_epairs = 1 + IPV6_PREFIXES: List[str] = [] + IPV4_PREFIXES: List[str] = [] + + def setup_method(self, method): + test_name = method.__name__ + vnet_name = "jail_{}".format(test_name) + ifaces = [] + for i in range(self.num_epairs): + ifaces.append(VnetInterface.create_iface("epair")) + self.vnet = VnetInstance.create_with_interfaces(vnet_name, ifaces) + self.vnet.attach() + for i, addr in enumerate(self.IPV6_PREFIXES): + if addr: + iface = self.vnet.ifaces[i] + iface.turn_up() + iface.enable_ipv6() + iface.setup_addr(addr) + for i, addr in enumerate(self.IPV4_PREFIXES): + if addr: + iface = self.vnet.ifaces[i] + iface.turn_up() + iface.setup_addr(addr) + + def cleanup(self, nodeid: str): + print("==== vnet cleanup ===") + VnetInstance.cleanup_vnets() + VnetInterface.cleanup_ifaces() + + def run_cmd(self, cmd: str) -> str: + return os.popen(cmd).read() |
