diff options
Diffstat (limited to 'tests/atf_python')
31 files changed, 5783 insertions, 0 deletions
diff --git a/tests/atf_python/Makefile b/tests/atf_python/Makefile new file mode 100644 index 000000000000..6b7c39948b0a --- /dev/null +++ b/tests/atf_python/Makefile @@ -0,0 +1,14 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +PACKAGE= tests + +FILES= __init__.py atf_pytest.py ktest.py utils.py +SUBDIR= sys + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python + + +.include <bsd.prog.mk> diff --git a/tests/atf_python/__init__.py b/tests/atf_python/__init__.py new file mode 100644 index 000000000000..6d5ec22ef054 --- /dev/null +++ b/tests/atf_python/__init__.py @@ -0,0 +1,4 @@ +import pytest + +pytest.register_assert_rewrite("atf_python.sys.net.rtsock") +pytest.register_assert_rewrite("atf_python.sys.net.vnet") diff --git a/tests/atf_python/atf_pytest.py b/tests/atf_python/atf_pytest.py new file mode 100644 index 000000000000..02ed502ace67 --- /dev/null +++ b/tests/atf_python/atf_pytest.py @@ -0,0 +1,298 @@ +import types +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Tuple + +from atf_python.ktest import generate_ktests +from atf_python.utils import nodeid_to_method_name + +import pytest +import os + + +class ATFCleanupItem(pytest.Item): + def runtest(self): + """Runs cleanup procedure for the test instead of the test itself""" + instance = self.parent.cls() + cleanup_name = "cleanup_{}".format(nodeid_to_method_name(self.nodeid)) + if hasattr(instance, cleanup_name): + cleanup = getattr(instance, cleanup_name) + cleanup(self.nodeid) + elif hasattr(instance, "cleanup"): + instance.cleanup(self.nodeid) + + def setup_method_noop(self, method): + """Overrides runtest setup method""" + pass + + def teardown_method_noop(self, method): + """Overrides runtest teardown method""" + pass + + +class ATFTestObj(object): + def __init__(self, obj, has_cleanup): + # Use nodeid without name to properly name class-derived tests + self.ident = obj.nodeid.split("::", 1)[1] + self.description = self._get_test_description(obj) + self.has_cleanup = has_cleanup + self.obj = obj + + def _get_test_description(self, obj): + """Returns first non-empty line from func docstring or func name""" + if getattr(obj, "descr", None) is not None: + return getattr(obj, "descr") + docstr = obj.function.__doc__ + if docstr: + for line in docstr.split("\n"): + if line: + return line + return obj.name + + @staticmethod + def _convert_user_mark(mark, obj, ret: Dict): + username = mark.args[0] + if username == "unprivileged": + # Special unprivileged user requested. + # First, require the unprivileged-user config option presence + key = "require.config" + if key not in ret: + ret[key] = "unprivileged_user" + else: + ret[key] = "{} {}".format(ret[key], "unprivileged_user") + # Check if the framework requires root + test_cls = ATFHandler.get_test_class(obj) + if test_cls and getattr(test_cls, "NEED_ROOT", False): + # Yes, so we ask kyua to run us under root instead + # It is up to the implementation to switch back to the desired + # user + ret["require.user"] = "root" + else: + ret["require.user"] = username + + def _convert_marks(self, obj) -> Dict[str, Any]: + wj_func = lambda x: " ".join(x) # noqa: E731 + _map: Dict[str, Dict] = { + "require_user": {"handler": self._convert_user_mark}, + "require_arch": {"name": "require.arch", "fmt": wj_func}, + "require_diskspace": {"name": "require.diskspace"}, + "require_files": {"name": "require.files", "fmt": wj_func}, + "require_machine": {"name": "require.machine", "fmt": wj_func}, + "require_memory": {"name": "require.memory"}, + "require_progs": {"name": "require.progs", "fmt": wj_func}, + "timeout": {}, + } + ret = {} + for mark in obj.iter_markers(): + if mark.name in _map: + if "handler" in _map[mark.name]: + _map[mark.name]["handler"](mark, obj, ret) + continue + name = _map[mark.name].get("name", mark.name) + if "fmt" in _map[mark.name]: + val = _map[mark.name]["fmt"](mark.args[0]) + else: + val = mark.args[0] + ret[name] = val + return ret + + def as_lines(self) -> List[str]: + """Output test definition in ATF-specific format""" + ret = [] + ret.append("ident: {}".format(self.ident)) + ret.append("descr: {}".format(self._get_test_description(self.obj))) + if self.has_cleanup: + ret.append("has.cleanup: true") + for key, value in self._convert_marks(self.obj).items(): + ret.append("{}: {}".format(key, value)) + return ret + + +class ATFHandler(object): + class ReportState(NamedTuple): + state: str + reason: str + + def __init__(self, report_file_name: Optional[str]): + self._tests_state_map: Dict[str, ReportStatus] = {} + self._report_file_name = report_file_name + self._report_file_handle = None + + def setup_configure(self): + fname = self._report_file_name + if fname: + self._report_file_handle = open(fname, mode="w") + + def setup_method_pre(self, item): + """Called before actually running the test setup_method""" + # Check if we need to manually drop the privileges + for mark in item.iter_markers(): + if mark.name == "require_user": + cls = self.get_test_class(item) + cls.TARGET_USER = mark.args[0] + break + + def override_runtest(self, obj): + # Override basic runtest command + obj.runtest = types.MethodType(ATFCleanupItem.runtest, obj) + # Override class setup/teardown + obj.parent.cls.setup_method = ATFCleanupItem.setup_method_noop + obj.parent.cls.teardown_method = ATFCleanupItem.teardown_method_noop + + @staticmethod + def get_test_class(obj): + if hasattr(obj, "parent") and obj.parent is not None: + if hasattr(obj.parent, "cls"): + return obj.parent.cls + + def has_object_cleanup(self, obj): + cls = self.get_test_class(obj) + if cls is not None: + method_name = nodeid_to_method_name(obj.nodeid) + cleanup_name = "cleanup_{}".format(method_name) + if hasattr(cls, "cleanup") or hasattr(cls, cleanup_name): + return True + return False + + def _generate_test_cleanups(self, items): + new_items = [] + for obj in items: + if self.has_object_cleanup(obj): + self.override_runtest(obj) + new_items.append(obj) + items.clear() + items.extend(new_items) + + def expand_tests(self, collector, name, obj): + return generate_ktests(collector, name, obj) + + def modify_tests(self, items, config): + if config.option.atf_cleanup: + self._generate_test_cleanups(items) + + def list_tests(self, tests: List[str]): + print('Content-Type: application/X-atf-tp; version="1"') + print() + for test_obj in tests: + has_cleanup = self.has_object_cleanup(test_obj) + atf_test = ATFTestObj(test_obj, has_cleanup) + for line in atf_test.as_lines(): + print(line) + print() + + def set_report_state(self, test_name: str, state: str, reason: str): + self._tests_state_map[test_name] = self.ReportState(state, reason) + + def _extract_report_reason(self, report): + data = report.longrepr + if data is None: + return None + if isinstance(data, Tuple): + # ('/path/to/test.py', 23, 'Skipped: unable to test') + reason = data[2] + for prefix in "Skipped: ": + if reason.startswith(prefix): + reason = reason[len(prefix):] + return reason + else: + # string/ traceback / exception report. Capture the last line + return str(data).split("\n")[-1] + return None + + def add_report(self, report): + # MAP pytest report state to the atf-desired state + # + # ATF test states: + # (1) expected_death, (2) expected_exit, (3) expected_failure + # (4) expected_signal, (5) expected_timeout, (6) passed + # (7) skipped, (8) failed + # + # Note that ATF don't have the concept of "soft xfail" - xpass + # is a failure. It also calls teardown routine in a separate + # process, thus teardown states (pytest-only) are handled as + # body continuation. + + # (stage, state, wasxfail) + + # Just a passing test: WANT: passed + # GOT: (setup, passed, F), (call, passed, F), (teardown, passed, F) + # + # Failing body test: WHAT: failed + # GOT: (setup, passed, F), (call, failed, F), (teardown, passed, F) + # + # pytest.skip test decorator: WANT: skipped + # GOT: (setup,skipped, False), (teardown, passed, False) + # + # pytest.skip call inside test function: WANT: skipped + # GOT: (setup, passed, F), (call, skipped, F), (teardown,passed, F) + # + # mark.xfail decorator+pytest.xfail: WANT: expected_failure + # GOT: (setup, passed, F), (call, skipped, T), (teardown, passed, F) + # + # mark.xfail decorator+pass: WANT: failed + # GOT: (setup, passed, F), (call, passed, T), (teardown, passed, F) + + test_name = report.location[2] + stage = report.when + state = report.outcome + reason = self._extract_report_reason(report) + + # We don't care about strict xfail - it gets translated to False + + if stage == "setup": + if state in ("skipped", "failed"): + # failed init -> failed test, skipped setup -> xskip + # for the whole test + self.set_report_state(test_name, state, reason) + elif stage == "call": + # "call" stage shouldn't matter if setup failed + if test_name in self._tests_state_map: + if self._tests_state_map[test_name].state == "failed": + return + if state == "failed": + # Record failure & override "skipped" state + self.set_report_state(test_name, state, reason) + elif state == "skipped": + if hasattr(report, "wasxfail"): + # xfail() called in the test body + state = "expected_failure" + else: + # skip inside the body + pass + self.set_report_state(test_name, state, reason) + elif state == "passed": + if hasattr(report, "wasxfail"): + # the test was expected to fail but didn't + # mark as hard failure + state = "failed" + self.set_report_state(test_name, state, reason) + elif stage == "teardown": + if state == "failed": + # teardown should be empty, as the cleanup + # procedures should be implemented as a separate + # function/method, so mark teardown failure as + # global failure + self.set_report_state(test_name, state, reason) + + def write_report(self): + if self._report_file_handle is None: + return + if self._tests_state_map: + # If we're executing in ATF mode, there has to be just one test + # Anyway, deterministically pick the first one + first_test_name = next(iter(self._tests_state_map)) + test = self._tests_state_map[first_test_name] + if test.state == "passed": + line = test.state + else: + line = "{}: {}".format(test.state, test.reason) + print(line, file=self._report_file_handle) + self._report_file_handle.close() + + @staticmethod + def get_atf_vars() -> Dict[str, str]: + px = "_ATF_VAR_" + return {k[len(px):]: v for k, v in os.environ.items() if k.startswith(px)} diff --git a/tests/atf_python/ktest.py b/tests/atf_python/ktest.py new file mode 100644 index 000000000000..a18f47d1dd06 --- /dev/null +++ b/tests/atf_python/ktest.py @@ -0,0 +1,173 @@ +import logging +import time +from typing import NamedTuple + +import pytest +from atf_python.sys.netlink.attrs import NlAttrNested +from atf_python.sys.netlink.attrs import NlAttrStr +from atf_python.sys.netlink.netlink import NetlinkMultipartIterator +from atf_python.sys.netlink.netlink import NlHelper +from atf_python.sys.netlink.netlink import Nlsock +from atf_python.sys.netlink.netlink_generic import KtestAttrType +from atf_python.sys.netlink.netlink_generic import KtestInfoMessage +from atf_python.sys.netlink.netlink_generic import KtestLogMsgType +from atf_python.sys.netlink.netlink_generic import KtestMsgAttrType +from atf_python.sys.netlink.netlink_generic import KtestMsgType +from atf_python.sys.netlink.netlink_generic import timespec +from atf_python.sys.netlink.utils import NlConst +from atf_python.utils import BaseTest +from atf_python.utils import libc +from atf_python.utils import nodeid_to_method_name + + +datefmt = "%H:%M:%S" +fmt = "%(asctime)s.%(msecs)03d %(filename)s:%(funcName)s:%(lineno)d %(message)s" +logging.basicConfig(level=logging.DEBUG, format=fmt, datefmt=datefmt) +logger = logging.getLogger("ktest") + + +NETLINK_FAMILY = "ktest" + + +class KtestItem(pytest.Item): + def __init__(self, *, descr, kcls, **kwargs): + super().__init__(**kwargs) + self.descr = descr + self._kcls = kcls + + def runtest(self): + self._kcls().runtest() + + +class KtestCollector(pytest.Class): + def collect(self): + obj = self.obj + exclude_names = set([n for n in dir(obj) if not n.startswith("_")]) + + autoload = obj.KTEST_MODULE_AUTOLOAD + module_name = obj.KTEST_MODULE_NAME + loader = KtestLoader(module_name, autoload) + ktests = loader.load_ktests() + if not ktests: + return + + orig = pytest.Class.from_parent(self.parent, name=self.name, obj=obj) + for py_test in orig.collect(): + yield py_test + + for ktest in ktests: + name = ktest["name"] + descr = ktest["desc"] + if name in exclude_names: + continue + yield KtestItem.from_parent(self, name=name, descr=descr, kcls=obj) + + +class KtestLoader(object): + def __init__(self, module_name: str, autoload: bool): + self.module_name = module_name + self.autoload = autoload + self.helper = NlHelper() + self.nlsock = Nlsock(NlConst.NETLINK_GENERIC, self.helper) + self.family_id = self._get_family_id() + + def _get_family_id(self): + try: + family_id = self.nlsock.get_genl_family_id(NETLINK_FAMILY) + except ValueError: + if self.autoload: + libc.kldload(self.module_name) + family_id = self.nlsock.get_genl_family_id(NETLINK_FAMILY) + else: + raise + return family_id + + def _load_ktests(self): + msg = KtestInfoMessage(self.helper, self.family_id, KtestMsgType.KTEST_CMD_LIST) + msg.set_request() + msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_MOD_NAME, self.module_name)) + self.nlsock.write_message(msg, verbose=False) + nlmsg_seq = msg.nl_hdr.nlmsg_seq + + ret = [] + for rx_msg in NetlinkMultipartIterator(self.nlsock, nlmsg_seq, self.family_id): + # rx_msg.print_message() + tst = { + "mod_name": rx_msg.get_nla(KtestAttrType.KTEST_ATTR_MOD_NAME).text, + "name": rx_msg.get_nla(KtestAttrType.KTEST_ATTR_TEST_NAME).text, + "desc": rx_msg.get_nla(KtestAttrType.KTEST_ATTR_TEST_DESCR).text, + } + ret.append(tst) + return ret + + def load_ktests(self): + ret = self._load_ktests() + if not ret and self.autoload: + libc.kldload(self.module_name) + ret = self._load_ktests() + return ret + + +def generate_ktests(collector, name, obj): + if getattr(obj, "KTEST_MODULE_NAME", None) is not None: + return KtestCollector.from_parent(collector, name=name, obj=obj) + return None + + +class BaseKernelTest(BaseTest): + KTEST_MODULE_AUTOLOAD = True + KTEST_MODULE_NAME = None + + def _get_record_time(self, msg) -> float: + timespec = msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_TS).ts + epoch_ktime = timespec.tv_sec * 1.0 + timespec.tv_nsec * 1.0 / 1000000000 + if not hasattr(self, "_start_epoch"): + self._start_ktime = epoch_ktime + self._start_time = time.time() + epoch_time = self._start_time + else: + epoch_time = time.time() - self._start_time + epoch_ktime + return epoch_time + + def _log_message(self, msg): + # Convert syslog-type l + syslog_level = msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_LEVEL).u8 + if syslog_level <= 6: + loglevel = logging.INFO + else: + loglevel = logging.DEBUG + rec = logging.LogRecord( + self.KTEST_MODULE_NAME, + loglevel, + msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_FILE).text, + msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_LINE).u32, + "%s", + (msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_TEXT).text), + None, + msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_FUNC).text, + None, + ) + rec.created = self._get_record_time(msg) + logger.handle(rec) + + def _runtest_name(self, test_name: str, test_data): + module_name = self.KTEST_MODULE_NAME + # print("Running kernel test {} for module {}".format(test_name, module_name)) + helper = NlHelper() + nlsock = Nlsock(NlConst.NETLINK_GENERIC, helper) + family_id = nlsock.get_genl_family_id(NETLINK_FAMILY) + msg = KtestInfoMessage(helper, family_id, KtestMsgType.KTEST_CMD_RUN) + msg.set_request() + msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_MOD_NAME, module_name)) + msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_TEST_NAME, test_name)) + if test_data is not None: + msg.add_nla(NlAttrNested(KtestAttrType.KTEST_ATTR_TEST_META, test_data)) + nlsock.write_message(msg, verbose=False) + + for log_msg in NetlinkMultipartIterator( + nlsock, msg.nl_hdr.nlmsg_seq, family_id + ): + self._log_message(log_msg) + + def runtest(self, test_data=None): + self._runtest_name(nodeid_to_method_name(self.test_id), test_data) diff --git a/tests/atf_python/sys/Makefile b/tests/atf_python/sys/Makefile new file mode 100644 index 000000000000..a5a1a532104d --- /dev/null +++ b/tests/atf_python/sys/Makefile @@ -0,0 +1,12 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +PACKAGE=tests +FILES= __init__.py +SUBDIR= net netlink netpfil + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python/sys + +.include <bsd.prog.mk> diff --git a/tests/atf_python/sys/__init__.py b/tests/atf_python/sys/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 --- /dev/null +++ b/tests/atf_python/sys/__init__.py diff --git a/tests/atf_python/sys/net/Makefile b/tests/atf_python/sys/net/Makefile new file mode 100644 index 000000000000..70d5b1a3284b --- /dev/null +++ b/tests/atf_python/sys/net/Makefile @@ -0,0 +1,11 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +PACKAGE=tests +FILES= __init__.py rtsock.py tools.py vnet.py + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python/sys/net + +.include <bsd.prog.mk> diff --git a/tests/atf_python/sys/net/__init__.py b/tests/atf_python/sys/net/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 --- /dev/null +++ b/tests/atf_python/sys/net/__init__.py diff --git a/tests/atf_python/sys/net/rtsock.py b/tests/atf_python/sys/net/rtsock.py new file mode 100755 index 000000000000..788e863f8b28 --- /dev/null +++ b/tests/atf_python/sys/net/rtsock.py @@ -0,0 +1,604 @@ +#!/usr/local/bin/python3 +import os +import socket +import struct +import sys +from ctypes import c_byte +from ctypes import c_char +from ctypes import c_int +from ctypes import c_long +from ctypes import c_uint32 +from ctypes import c_ulong +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + + +def roundup2(val: int, num: int) -> int: + if val % num: + return (val | (num - 1)) + 1 + else: + return val + + +class RtSockException(OSError): + pass + + +class RtConst: + RTM_VERSION = 5 + ALIGN = sizeof(c_long) + + AF_INET = socket.AF_INET + AF_INET6 = socket.AF_INET6 + AF_LINK = socket.AF_LINK + + RTA_DST = 0x1 + RTA_GATEWAY = 0x2 + RTA_NETMASK = 0x4 + RTA_GENMASK = 0x8 + RTA_IFP = 0x10 + RTA_IFA = 0x20 + RTA_AUTHOR = 0x40 + RTA_BRD = 0x80 + + RTM_ADD = 1 + RTM_DELETE = 2 + RTM_CHANGE = 3 + RTM_GET = 4 + + RTF_UP = 0x1 + RTF_GATEWAY = 0x2 + RTF_HOST = 0x4 + RTF_REJECT = 0x8 + RTF_DYNAMIC = 0x10 + RTF_MODIFIED = 0x20 + RTF_DONE = 0x40 + RTF_XRESOLVE = 0x200 + RTF_LLINFO = 0x400 + RTF_LLDATA = 0x400 + RTF_STATIC = 0x800 + RTF_BLACKHOLE = 0x1000 + RTF_PROTO2 = 0x4000 + RTF_PROTO1 = 0x8000 + RTF_PROTO3 = 0x40000 + RTF_FIXEDMTU = 0x80000 + RTF_PINNED = 0x100000 + RTF_LOCAL = 0x200000 + RTF_BROADCAST = 0x400000 + RTF_MULTICAST = 0x800000 + RTF_STICKY = 0x10000000 + RTF_RNH_LOCKED = 0x40000000 + RTF_GWFLAG_COMPAT = 0x80000000 + + RTV_MTU = 0x1 + RTV_HOPCOUNT = 0x2 + RTV_EXPIRE = 0x4 + RTV_RPIPE = 0x8 + RTV_SPIPE = 0x10 + RTV_SSTHRESH = 0x20 + RTV_RTT = 0x40 + RTV_RTTVAR = 0x80 + RTV_WEIGHT = 0x100 + + @staticmethod + def get_props(prefix: str) -> List[str]: + return [n for n in dir(RtConst) if n.startswith(prefix)] + + @staticmethod + def get_name(prefix: str, value: int) -> str: + props = RtConst.get_props(prefix) + for prop in props: + if getattr(RtConst, prop) == value: + return prop + return "U:{}:{}".format(prefix, value) + + @staticmethod + def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]: + props = RtConst.get_props(prefix) + propmap = {getattr(RtConst, prop): prop for prop in props} + v = 1 + ret = {} + while value: + if v & value: + if v in propmap: + ret[v] = propmap[v] + else: + ret[v] = hex(v) + value -= v + v *= 2 + return ret + + @staticmethod + def get_bitmask_str(prefix: str, value: int) -> str: + bmap = RtConst.get_bitmask_map(prefix, value) + return ",".join([v for k, v in bmap.items()]) + + +class RtMetrics(Structure): + _fields_ = [ + ("rmx_locks", c_ulong), + ("rmx_mtu", c_ulong), + ("rmx_hopcount", c_ulong), + ("rmx_expire", c_ulong), + ("rmx_recvpipe", c_ulong), + ("rmx_sendpipe", c_ulong), + ("rmx_ssthresh", c_ulong), + ("rmx_rtt", c_ulong), + ("rmx_rttvar", c_ulong), + ("rmx_pksent", c_ulong), + ("rmx_weight", c_ulong), + ("rmx_nhidx", c_ulong), + ("rmx_filler", c_ulong * 2), + ] + + +class RtMsgHdr(Structure): + _fields_ = [ + ("rtm_msglen", c_ushort), + ("rtm_version", c_byte), + ("rtm_type", c_byte), + ("rtm_index", c_ushort), + ("_rtm_spare1", c_ushort), + ("rtm_flags", c_int), + ("rtm_addrs", c_int), + ("rtm_pid", c_int), + ("rtm_seq", c_int), + ("rtm_errno", c_int), + ("rtm_fmask", c_int), + ("rtm_inits", c_ulong), + ("rtm_rmx", RtMetrics), + ] + + +class SockaddrIn(Structure): + _fields_ = [ + ("sin_len", c_byte), + ("sin_family", c_byte), + ("sin_port", c_ushort), + ("sin_addr", c_uint32), + ("sin_zero", c_char * 8), + ] + + +class SockaddrIn6(Structure): + _fields_ = [ + ("sin6_len", c_byte), + ("sin6_family", c_byte), + ("sin6_port", c_ushort), + ("sin6_flowinfo", c_uint32), + ("sin6_addr", c_byte * 16), + ("sin6_scope_id", c_uint32), + ] + + +class SockaddrDl(Structure): + _fields_ = [ + ("sdl_len", c_byte), + ("sdl_family", c_byte), + ("sdl_index", c_ushort), + ("sdl_type", c_byte), + ("sdl_nlen", c_byte), + ("sdl_alen", c_byte), + ("sdl_slen", c_byte), + ("sdl_data", c_byte * 8), + ] + + +class SaHelper(object): + @staticmethod + def is_ipv6(ip: str) -> bool: + return ":" in ip + + @staticmethod + def ip_sa(ip: str, scopeid: int = 0) -> bytes: + if SaHelper.is_ipv6(ip): + return SaHelper.ip6_sa(ip, scopeid) + else: + return SaHelper.ip4_sa(ip) + + @staticmethod + def ip4_sa(ip: str) -> bytes: + addr_int = int.from_bytes(socket.inet_pton(2, ip), sys.byteorder) + sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int) + return bytes(sin) + + @staticmethod + def ip6_sa(ip6: str, scopeid: int) -> bytes: + addr_bytes = (c_byte * 16)() + for i, b in enumerate(socket.inet_pton(socket.AF_INET6, ip6)): + addr_bytes[i] = b + sin6 = SockaddrIn6( + sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid + ) + return bytes(sin6) + + @staticmethod + def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes: + sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype) + return bytes(sa) + + @staticmethod + def pxlen4_sa(pxlen: int) -> bytes: + return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen)) + + @staticmethod + def pxlen_to_ip4(pxlen: int) -> str: + if pxlen == 32: + return "255.255.255.255" + else: + addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1) + addr_bytes = struct.pack("!I", addr) + return socket.inet_ntop(socket.AF_INET, addr_bytes) + + @staticmethod + def pxlen6_sa(pxlen: int) -> bytes: + return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen)) + + @staticmethod + def pxlen_to_ip6(pxlen: int) -> str: + ip6_b = [0] * 16 + start = 0 + while pxlen > 8: + ip6_b[start] = 0xFF + pxlen -= 8 + start += 1 + ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1) + return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b)) + + @staticmethod + def print_sa_inet(sa: bytes): + if len(sa) < 8: + raise RtSockException("IPv4 sa size too small: {}".format(len(sa))) + addr = socket.inet_ntop(socket.AF_INET, sa[4:8]) + return "{}".format(addr) + + @staticmethod + def print_sa_inet6(sa: bytes): + if len(sa) < sizeof(SockaddrIn6): + raise RtSockException("IPv6 sa size too small: {}".format(len(sa))) + addr = socket.inet_ntop(socket.AF_INET6, sa[8:24]) + scopeid = struct.unpack(">I", sa[24:28])[0] + return "{} scopeid {}".format(addr, scopeid) + + @staticmethod + def print_sa_link(sa: bytes, hd: Optional[bool] = True): + if len(sa) < sizeof(SockaddrDl): + raise RtSockException("LINK sa size too small: {}".format(len(sa))) + sdl = SockaddrDl.from_buffer_copy(sa) + if sdl.sdl_index: + ifindex = "link#{} ".format(sdl.sdl_index) + else: + ifindex = "" + if sdl.sdl_nlen: + iface_offset = 8 + if sdl.sdl_nlen + iface_offset > len(sa): + raise RtSockException( + "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa)) + ) + ifname = "ifname:{} ".format( + bytes.decode(sa[iface_offset : iface_offset + sdl.sdl_nlen]) + ) + else: + ifname = "" + return "{}{}".format(ifindex, ifname) + + @staticmethod + def print_sa_unknown(sa: bytes): + return "unknown_type:{}".format(sa[1]) + + @classmethod + def print_sa(cls, sa: bytes, hd: Optional[bool] = False): + if sa[0] != len(sa): + raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa))) + + if len(sa) < 2: + raise Exception( + "sa type {} too short: {}".format( + RtConst.get_name("AF_", sa[1]), len(sa) + ) + ) + + if sa[1] == socket.AF_INET: + text = cls.print_sa_inet(sa) + elif sa[1] == socket.AF_INET6: + text = cls.print_sa_inet6(sa) + elif sa[1] == socket.AF_LINK: + text = cls.print_sa_link(sa) + else: + text = cls.print_sa_unknown(sa) + if hd: + dump = " [{!r}]".format(sa) + else: + dump = "" + return "{}{}".format(text, dump) + + +class BaseRtsockMessage(object): + def __init__(self, rtm_type): + self.rtm_type = rtm_type + self.sa = SaHelper() + + @staticmethod + def print_rtm_type(rtm_type): + return RtConst.get_name("RTM_", rtm_type) + + @property + def rtm_type_str(self): + return self.print_rtm_type(self.rtm_type) + + +class RtsockRtMessage(BaseRtsockMessage): + messages = [ + RtConst.RTM_ADD, + RtConst.RTM_DELETE, + RtConst.RTM_CHANGE, + RtConst.RTM_GET, + ] + + def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None): + super().__init__(rtm_type) + self.rtm_flags = 0 + self.rtm_seq = rtm_seq + self._attrs = {} + self.rtm_errno = 0 + self.rtm_pid = 0 + self.rtm_inits = 0 + self.rtm_rmx = RtMetrics() + self._orig_data = None + if dst_sa: + self.add_sa_attr(RtConst.RTA_DST, dst_sa) + if mask_sa: + self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa) + + def add_sa_attr(self, attr_type, attr_bytes: bytes): + self._attrs[attr_type] = attr_bytes + + def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0): + if ":" in ip_addr: + self.add_ip6_attr(attr_type, ip_addr, scopeid) + else: + self.add_ip4_attr(attr_type, ip_addr) + + def add_ip4_attr(self, attr_type, ip: str): + self.add_sa_attr(attr_type, self.sa.ip_sa(ip)) + + def add_ip6_attr(self, attr_type, ip6: str, scopeid: int): + self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid)) + + def add_link_attr(self, attr_type, ifindex: Optional[int] = 0): + self.add_sa_attr(attr_type, self.sa.link_sa(ifindex)) + + def get_sa(self, attr_type) -> bytes: + return self._attrs.get(attr_type) + + def print_message(self): + # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC> + if self._orig_data: + rtm_len = len(self._orig_data) + else: + rtm_len = len(bytes(self)) + print( + "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format( + self.rtm_type_str, + rtm_len, + self.rtm_pid, + self.rtm_seq, + self.rtm_errno, + RtConst.get_bitmask_str("RTF_", self.rtm_flags), + ) + ) + rtm_addrs = sum(list(self._attrs.keys())) + print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs))) + for attr in sorted(self._attrs.keys()): + sa_data = SaHelper.print_sa(self._attrs[attr]) + print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data)) + + def print_in_message(self): + print("vvvvvvvv IN vvvvvvvv") + self.print_message() + print() + + def verify_sa_inet(self, sa_data): + if len(sa_data) < 8: + raise Exception("IPv4 sa size too small: {}".format(sa_data)) + if sa_data[0] > len(sa_data): + raise Exception( + "IPv4 sin_len too big: {} vs sa size {}: {}".format( + sa_data[0], len(sa_data), sa_data + ) + ) + sin = SockaddrIn.from_buffer_copy(sa_data) + assert sin.sin_port == 0 + assert sin.sin_zero == [0] * 8 + + def compare_sa(self, sa_type, sa_data): + if len(sa_data) < 4: + sa_type_name = RtConst.get_name("RTA_", sa_type) + raise Exception( + "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data)) + ) + our_sa = self._attrs[sa_type] + assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa) + assert len(sa_data) == len(our_sa) + assert sa_data == our_sa + + def verify(self, rtm_type: int, rtm_sa): + assert self.rtm_type_str == self.print_rtm_type(rtm_type) + assert self.rtm_errno == 0 + hdr = RtMsgHdr.from_buffer_copy(self._orig_data) + assert hdr._rtm_spare1 == 0 + for sa_type, sa_data in rtm_sa.items(): + if sa_type not in self._attrs: + sa_type_name = RtConst.get_name("RTA_", sa_type) + raise Exception("SA type {} not present".format(sa_type_name)) + self.compare_sa(sa_type, sa_data) + + @classmethod + def from_bytes(cls, data: bytes): + if len(data) < sizeof(RtMsgHdr): + raise Exception( + "messages size {} is less than expected {}".format( + len(data), sizeof(RtMsgHdr) + ) + ) + hdr = RtMsgHdr.from_buffer_copy(data) + + self = cls(hdr.rtm_type) + self.rtm_flags = hdr.rtm_flags + self.rtm_seq = hdr.rtm_seq + self.rtm_errno = hdr.rtm_errno + self.rtm_pid = hdr.rtm_pid + self.rtm_inits = hdr.rtm_inits + self.rtm_rmx = hdr.rtm_rmx + self._orig_data = data + + off = sizeof(RtMsgHdr) + v = 1 + addrs_mask = hdr.rtm_addrs + while addrs_mask: + if addrs_mask & v: + addrs_mask -= v + + if off + data[off] > len(data): + raise Exception( + "SA sizeof for {} > total message length: {}+{} > {}".format( + RtConst.get_name("RTA_", v), off, data[off], len(data) + ) + ) + self._attrs[v] = data[off : off + data[off]] + off += roundup2(data[off], RtConst.ALIGN) + v *= 2 + return self + + def __bytes__(self): + sz = sizeof(RtMsgHdr) + addrs_mask = 0 + for k, v in self._attrs.items(): + sz += roundup2(len(v), RtConst.ALIGN) + addrs_mask += k + hdr = RtMsgHdr( + rtm_msglen=sz, + rtm_version=RtConst.RTM_VERSION, + rtm_type=self.rtm_type, + rtm_flags=self.rtm_flags, + rtm_seq=self.rtm_seq, + rtm_addrs=addrs_mask, + rtm_inits=self.rtm_inits, + rtm_rmx=self.rtm_rmx, + ) + buf = bytearray(sz) + buf[0 : sizeof(RtMsgHdr)] = hdr + off = sizeof(RtMsgHdr) + for attr in sorted(self._attrs.keys()): + v = self._attrs[attr] + sa_len = len(v) + buf[off : off + sa_len] = v + off += roundup2(len(v), RtConst.ALIGN) + return bytes(buf) + + +class Rtsock: + def __init__(self): + self.socket = self._setup_rtsock() + self.rtm_seq = 1 + self.msgmap = self.build_msgmap() + + def build_msgmap(self): + classes = [RtsockRtMessage] + xmap = {} + for cls in classes: + for message in cls.messages: + xmap[message] = cls + return xmap + + def get_seq(self): + ret = self.rtm_seq + self.rtm_seq += 1 + return ret + + def get_weight(self, weight) -> int: + if weight: + return weight + else: + return 1 # RT_DEFAULT_WEIGHT + + def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]): + px = prefix.split("/") + addr_sa = SaHelper.ip_sa(px[0]) + if len(px) > 1: + pxlen = int(px[1]) + if SaHelper.is_ipv6(px[0]): + mask_sa = SaHelper.pxlen6_sa(pxlen) + else: + mask_sa = SaHelper.pxlen4_sa(pxlen) + else: + mask_sa = None + msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa) + if isinstance(gw, bytes): + msg.add_sa_attr(RtConst.RTA_GATEWAY, gw) + else: + # String + msg.add_ip_attr(RtConst.RTA_GATEWAY, gw) + return msg + + def new_rtm_add(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw) + + def new_rtm_del(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw) + + def new_rtm_change(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw) + + def _setup_rtsock(self) -> socket.socket: + s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC) + s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1) + return s + + def print_hd(self, data: bytes): + width = 16 + print("==========================================") + for chunk in [data[i : i + width] for i in range(0, len(data), width)]: + for b in chunk: + print("0x{:02X} ".format(b), end="") + print() + print() + + def write_message(self, msg): + print("vvvvvvvv OUT vvvvvvvv") + msg.print_message() + print() + msg_bytes = bytes(msg) + ret = os.write(self.socket.fileno(), msg_bytes) + if ret != -1: + assert ret == len(msg_bytes) + + def parse_message(self, data: bytes): + if len(data) < 4: + raise OSError("Short read from rtsock: {} bytes".format(len(data))) + rtm_type = data[4] + if rtm_type not in self.msgmap: + return None + + def write_data(self, data: bytes): + self.socket.send(data) + + def read_data(self, seq: Optional[int] = None) -> bytes: + while True: + data = self.socket.recv(4096) + if seq is None: + break + if len(data) > sizeof(RtMsgHdr): + hdr = RtMsgHdr.from_buffer_copy(data) + if hdr.rtm_seq == seq: + break + return data + + def read_message(self) -> bytes: + data = self.read_data() + return self.parse_message(data) diff --git a/tests/atf_python/sys/net/tools.py b/tests/atf_python/sys/net/tools.py new file mode 100644 index 000000000000..44bd74d8578f --- /dev/null +++ b/tests/atf_python/sys/net/tools.py @@ -0,0 +1,100 @@ +#!/usr/local/bin/python3 +import json +import os +import subprocess + + +class ToolsHelper(object): + NETSTAT_PATH = "/usr/bin/netstat" + IFCONFIG_PATH = "/sbin/ifconfig" + + @classmethod + def get_output(cls, cmd: str, verbose=False) -> str: + if verbose: + print("run: '{}'".format(cmd)) + return os.popen(cmd).read() + + @classmethod + def pf_rules(cls, rules, verbose=True): + pf_conf = "" + for r in rules: + pf_conf = pf_conf + r + "\n" + + if verbose: + print("Set rules:") + print(pf_conf) + + ps = subprocess.Popen("/sbin/pfctl -g -f -", shell=True, + stdin=subprocess.PIPE) + ps.communicate(bytes(pf_conf, 'utf-8')) + ret = ps.wait() + if ret != 0: + raise Exception("Failed to set pf rules %d" % ret) + + if verbose: + cls.print_output("/sbin/pfctl -sr") + + @classmethod + def print_output(cls, cmd: str, verbose=True): + if verbose: + print("======= {} =====".format(cmd)) + print(cls.get_output(cmd)) + if verbose: + print() + + @classmethod + def print_net_debug(cls): + cls.print_output("ifconfig") + cls.print_output("netstat -rnW") + + @classmethod + def set_sysctl(cls, oid, val): + cls.get_output("sysctl {}={}".format(oid, val)) + + @classmethod + def get_routes(cls, family: str, fibnum: int = 0): + family_key = {"inet": "-4", "inet6": "-6"}.get(family) + out = cls.get_output( + "{} {} -rnW -F {} --libxo json".format(cls.NETSTAT_PATH, family_key, fibnum) + ) + js = json.loads(out) + js = js["statistics"]["route-information"]["route-table"]["rt-family"] + if js: + return js[0]["rt-entry"] + else: + return [] + + @classmethod + def get_nhops(cls, family: str, fibnum: int = 0): + family_key = {"inet": "-4", "inet6": "-6"}.get(family) + out = cls.get_output( + "{} {} -onW -F {} --libxo json".format(cls.NETSTAT_PATH, family_key, fibnum) + ) + js = json.loads(out) + js = js["statistics"]["route-nhop-information"]["nhop-table"]["rt-family"] + if js: + return js[0]["nh-entry"] + else: + return [] + + @classmethod + def get_linklocals(cls): + ret = {} + ifname = None + ips = [] + for line in cls.get_output(cls.IFCONFIG_PATH).splitlines(): + if line[0].isalnum(): + if ifname: + ret[ifname] = ips + ips = [] + ifname = line.split(":")[0] + else: + words = line.split() + if words[0] == "inet6" and words[1].startswith("fe80"): + # inet6 fe80::1%lo0 prefixlen 64 scopeid 0x2 + ip = words[1].split("%")[0] + scopeid = int(words[words.index("scopeid") + 1], 16) + ips.append((ip, scopeid)) + if ifname: + ret[ifname] = ips + return ret diff --git a/tests/atf_python/sys/net/vnet.py b/tests/atf_python/sys/net/vnet.py new file mode 100644 index 000000000000..f75a3eaa693e --- /dev/null +++ b/tests/atf_python/sys/net/vnet.py @@ -0,0 +1,559 @@ +#!/usr/local/bin/python3 +import copy +import ipaddress +import os +import re +import socket +import sys +import time +from multiprocessing import connection +from multiprocessing import Pipe +from multiprocessing import Process +from typing import Dict +from typing import List +from typing import NamedTuple + +from atf_python.sys.net.tools import ToolsHelper +from atf_python.utils import BaseTest +from atf_python.utils import libc + + +def run_cmd(cmd: str, verbose=True) -> str: + if verbose: + print("run: '{}'".format(cmd)) + return os.popen(cmd).read() + + +def get_topology_id(test_id: str) -> str: + """ + Gets a unique topology id based on the pytest test_id. + "test_ip6_output.py::TestIP6Output::test_output6_pktinfo[ipandif]" -> + "TestIP6Output:test_output6_pktinfo[ipandif]" + """ + return ":".join(test_id.split("::")[-2:]) + + +def convert_test_name(test_name: str) -> str: + """Convert test name to a string that can be used in the file/jail names""" + ret = "" + for char in test_name: + if char.isalnum() or char in ("_", "-", ":"): + ret += char + elif char in ("["): + ret += "_" + return ret + + +class VnetInterface(object): + # defines from net/if_types.h + IFT_LOOP = 0x18 + IFT_ETHER = 0x06 + + def __init__(self, iface_alias: str, iface_name: str): + self.name = iface_name + self.alias = iface_alias + self.vnet_name = "" + self.jailed = False + self.addr_map: Dict[str, Dict] = {"inet6": {}, "inet": {}} + self.prefixes4: List[List[str]] = [] + self.prefixes6: List[List[str]] = [] + if iface_name.startswith("lo"): + self.iftype = self.IFT_LOOP + else: + self.iftype = self.IFT_ETHER + self.ether = ToolsHelper.get_output("/sbin/ifconfig %s ether | awk '/ether/ { print $2; }'" % iface_name).rstrip() + + @property + def ifindex(self): + return socket.if_nametoindex(self.name) + + @property + def first_ipv6(self): + d = self.addr_map["inet6"] + return d[next(iter(d))] + + @property + def first_ipv4(self): + d = self.addr_map["inet"] + return d[next(iter(d))] + + def set_vnet(self, vnet_name: str): + self.vnet_name = vnet_name + + def set_jailed(self, jailed: bool): + self.jailed = jailed + + def run_cmd(self, cmd, verbose=False): + if self.vnet_name and not self.jailed: + cmd = "/usr/sbin/jexec {} {}".format(self.vnet_name, cmd) + return run_cmd(cmd, verbose) + + @classmethod + def setup_loopback(cls, vnet_name: str): + lo = VnetInterface("", "lo0") + lo.set_vnet(vnet_name) + lo.setup_addr("127.0.0.1/8") + lo.turn_up() + + @classmethod + def create_iface(cls, alias_name: str, iface_name: str) -> List["VnetInterface"]: + name = run_cmd("/sbin/ifconfig {} create".format(iface_name)).rstrip() + if not name: + raise Exception("Unable to create iface {}".format(iface_name)) + if1 = cls(alias_name, name) + ret = [if1] + if name.startswith("epair"): + run_cmd("/sbin/ifconfig {} -txcsum -txcsum6".format(name)) + if2 = cls(alias_name, name[:-1] + "b") + if1.epairb = if2 + ret.append(if2); + return ret + + def setup_addr(self, _addr: str): + addr = ipaddress.ip_interface(_addr) + if addr.version == 6: + family = "inet6" + cmd = "/sbin/ifconfig {} {} {}".format(self.name, family, addr) + else: + family = "inet" + if self.addr_map[family]: + cmd = "/sbin/ifconfig {} alias {}".format(self.name, addr) + else: + cmd = "/sbin/ifconfig {} {} {}".format(self.name, family, addr) + self.run_cmd(cmd) + self.addr_map[family][str(addr.ip)] = addr + + def delete_addr(self, _addr: str): + addr = ipaddress.ip_address(_addr) + if addr.version == 6: + family = "inet6" + cmd = "/sbin/ifconfig {} inet6 {} delete".format(self.name, addr) + else: + family = "inet" + cmd = "/sbin/ifconfig {} -alias {}".format(self.name, addr) + self.run_cmd(cmd) + del self.addr_map[family][str(addr)] + + def turn_up(self): + cmd = "/sbin/ifconfig {} up".format(self.name) + self.run_cmd(cmd) + + def enable_ipv6(self): + cmd = "/usr/sbin/ndp -i {} -- -disabled".format(self.name) + self.run_cmd(cmd) + + def has_tentative(self) -> bool: + """True if an interface has some addresses in tenative state""" + cmd = "/sbin/ifconfig {} inet6".format(self.name) + out = self.run_cmd(cmd, verbose=False) + for line in out.splitlines(): + if "tentative" in line: + return True + return False + + +class IfaceFactory(object): + INTERFACES_FNAME = "created_ifaces.lst" + AUTODELETE_TYPES = ("epair", "gif", "gre", "lo", "tap", "tun") + + def __init__(self): + self.file_name = self.INTERFACES_FNAME + + def _register_iface(self, iface_name: str): + with open(self.file_name, "a") as f: + f.write(iface_name + "\n") + + def _list_ifaces(self) -> List[str]: + ret: List[str] = [] + try: + with open(self.file_name, "r") as f: + for line in f: + ret.append(line.strip()) + except OSError: + pass + return ret + + def create_iface(self, alias_name: str, iface_name: str) -> List[VnetInterface]: + ifaces = VnetInterface.create_iface(alias_name, iface_name) + for iface in ifaces: + if not self.is_autodeleted(iface.name): + self._register_iface(iface.name) + return ifaces + + @staticmethod + def is_autodeleted(iface_name: str) -> bool: + if iface_name == "lo0": + return False + iface_type = re.split(r"\d+", iface_name)[0] + return iface_type in IfaceFactory.AUTODELETE_TYPES + + def cleanup_vnet_interfaces(self, vnet_name: str) -> List[str]: + """Destroys""" + ifaces_lst = ToolsHelper.get_output( + "/usr/sbin/jexec {} /sbin/ifconfig -l".format(vnet_name) + ) + for iface_name in ifaces_lst.split(): + if not self.is_autodeleted(iface_name): + if iface_name not in self._list_ifaces(): + print("Skipping interface {}:{}".format(vnet_name, iface_name)) + continue + run_cmd( + "/usr/sbin/jexec {} /sbin/ifconfig {} destroy".format(vnet_name, iface_name) + ) + + def cleanup(self): + try: + os.unlink(self.INTERFACES_FNAME) + except OSError: + pass + + +class VnetInstance(object): + def __init__( + self, vnet_alias: str, vnet_name: str, jid: int, ifaces: List[VnetInterface] + ): + self.name = vnet_name + self.alias = vnet_alias # reference in the test topology + self.jid = jid + self.ifaces = ifaces + self.iface_alias_map = {} # iface.alias: iface + self.iface_map = {} # iface.name: iface + for iface in ifaces: + iface.set_vnet(vnet_name) + iface.set_jailed(True) + self.iface_alias_map[iface.alias] = iface + self.iface_map[iface.name] = iface + # Allow reference to interfce aliases as attributes + setattr(self, iface.alias, iface) + self.need_dad = False # Disable duplicate address detection by default + self.attached = False + self.pipe = None + self.subprocess = None + + def run_vnet_cmd(self, cmd, verbose=True): + if not self.attached: + cmd = "/usr/sbin/jexec {} {}".format(self.name, cmd) + return run_cmd(cmd, verbose) + + def disable_dad(self): + self.run_vnet_cmd("/sbin/sysctl net.inet6.ip6.dad_count=0") + + def set_pipe(self, pipe): + self.pipe = pipe + + def set_subprocess(self, p): + self.subprocess = p + + @staticmethod + def attach_jid(jid: int): + error_code = libc.jail_attach(jid) + if error_code != 0: + raise Exception("jail_attach() failed: errno {}".format(error_code)) + + def attach(self): + self.attach_jid(self.jid) + self.attached = True + + +class VnetFactory(object): + JAILS_FNAME = "created_jails.lst" + + def __init__(self, topology_id: str): + self.topology_id = topology_id + self.file_name = self.JAILS_FNAME + self._vnets: List[str] = [] + + def _register_vnet(self, vnet_name: str): + self._vnets.append(vnet_name) + with open(self.file_name, "a") as f: + f.write(vnet_name + "\n") + + @staticmethod + def _wait_interfaces(vnet_name: str, ifaces: List[str]) -> List[str]: + cmd = "/usr/sbin/jexec {} /sbin/ifconfig -l".format(vnet_name) + not_matched: List[str] = [] + for i in range(50): + vnet_ifaces = run_cmd(cmd).strip().split(" ") + not_matched = [] + for iface_name in ifaces: + if iface_name not in vnet_ifaces: + not_matched.append(iface_name) + if len(not_matched) == 0: + return [] + time.sleep(0.1) + return not_matched + + def create_vnet(self, vnet_alias: str, ifaces: List[VnetInterface], opts: List[str]): + vnet_name = "pytest:{}".format(convert_test_name(self.topology_id)) + if self._vnets: + # add number to distinguish jails + vnet_name = "{}_{}".format(vnet_name, len(self._vnets) + 1) + iface_cmds = " ".join(["vnet.interface={}".format(i.name) for i in ifaces]) + opt_cmds = " ".join(["{}".format(i) for i in opts]) + cmd = "/usr/sbin/jail -i -c name={} persist vnet {} {}".format( + vnet_name, iface_cmds, opt_cmds + ) + jid = 0 + try: + jid_str = run_cmd(cmd) + jid = int(jid_str) + except ValueError: + print("Jail creation failed, output: {}".format(jid_str)) + raise + self._register_vnet(vnet_name) + + # Run expedited version of routing + VnetInterface.setup_loopback(vnet_name) + + not_found = self._wait_interfaces(vnet_name, [i.name for i in ifaces]) + if not_found: + raise Exception( + "Interfaces {} has not appeared in vnet {}".format(not_found, vnet_name) + ) + return VnetInstance(vnet_alias, vnet_name, jid, ifaces) + + def cleanup(self): + iface_factory = IfaceFactory() + try: + with open(self.file_name) as f: + for line in f: + vnet_name = line.strip() + iface_factory.cleanup_vnet_interfaces(vnet_name) + run_cmd("/usr/sbin/jail -r {}".format(vnet_name)) + os.unlink(self.JAILS_FNAME) + except OSError: + pass + + +class SingleInterfaceMap(NamedTuple): + ifaces: List[VnetInterface] + vnet_aliases: List[str] + + +class ObjectsMap(NamedTuple): + iface_map: Dict[str, SingleInterfaceMap] # keyed by ifX + vnet_map: Dict[str, VnetInstance] # keyed by vnetX + topo_map: Dict # self.TOPOLOGY + + +class VnetTestTemplate(BaseTest): + NEED_ROOT: bool = True + TOPOLOGY = {} + + def _require_default_modules(self): + libc.kldload("if_epair.ko") + self.require_module("if_epair") + + def _get_vnet_handler(self, vnet_alias: str): + handler_name = "{}_handler".format(vnet_alias) + return getattr(self, handler_name, None) + + def _setup_vnet(self, vnet: VnetInstance, obj_map: Dict, pipe): + """Base Handler to setup given VNET. + Can be run in a subprocess. If so, passes control to the special + vnetX_handler() after setting up interface addresses + """ + vnet.attach() + print("# setup_vnet({})".format(vnet.name)) + if pipe is not None: + vnet.set_pipe(pipe) + + topo = obj_map.topo_map + ipv6_ifaces = [] + # Disable DAD + if not vnet.need_dad: + vnet.disable_dad() + for iface in vnet.ifaces: + # check index of vnet within an interface + # as we have prefixes for both ends of the interface + iface_map = obj_map.iface_map[iface.alias] + idx = iface_map.vnet_aliases.index(vnet.alias) + prefixes6 = topo[iface.alias].get("prefixes6", []) + prefixes4 = topo[iface.alias].get("prefixes4", []) + if prefixes6 or prefixes4: + ipv6_ifaces.append(iface) + iface.turn_up() + if prefixes6: + iface.enable_ipv6() + for prefix in prefixes6 + prefixes4: + if prefix[idx]: + iface.setup_addr(prefix[idx]) + for iface in ipv6_ifaces: + while iface.has_tentative(): + time.sleep(0.1) + + # Run actual handler + handler = self._get_vnet_handler(vnet.alias) + if handler: + # Do unbuffered stdout for children + # so the logs are present if the child hangs + sys.stdout.reconfigure(line_buffering=True) + self.drop_privileges() + handler(vnet) + + def _get_topo_ifmap(self, topo: Dict): + iface_factory = IfaceFactory() + iface_map: Dict[str, SingleInterfaceMap] = {} + iface_aliases = set() + for obj_name, obj_data in topo.items(): + if obj_name.startswith("vnet"): + for iface_alias in obj_data["ifaces"]: + iface_aliases.add(iface_alias) + for iface_alias in iface_aliases: + print("Creating {}".format(iface_alias)) + iface_data = topo[iface_alias] + iface_type = iface_data.get("type", "epair") + ifaces = iface_factory.create_iface(iface_alias, iface_type) + smap = SingleInterfaceMap(ifaces, []) + iface_map[iface_alias] = smap + return iface_map + + def setup_topology(self, topo: Dict, topology_id: str): + """Creates jails & interfaces for the provided topology""" + vnet_map = {} + vnet_factory = VnetFactory(topology_id) + iface_map = self._get_topo_ifmap(topo) + for obj_name, obj_data in topo.items(): + if obj_name.startswith("vnet"): + vnet_ifaces = [] + for iface_alias in obj_data["ifaces"]: + # epair creates 2 interfaces, grab first _available_ + # and map it to the VNET being created + idx = len(iface_map[iface_alias].vnet_aliases) + iface_map[iface_alias].vnet_aliases.append(obj_name) + vnet_ifaces.append(iface_map[iface_alias].ifaces[idx]) + opts = [] + if "opts" in obj_data: + opts = obj_data["opts"] + vnet = vnet_factory.create_vnet(obj_name, vnet_ifaces, opts) + vnet_map[obj_name] = vnet + # Allow reference to VNETs as attributes + setattr(self, obj_name, vnet) + # Debug output + print("============= TEST TOPOLOGY =============") + for vnet_alias, vnet in vnet_map.items(): + print("# vnet {} -> {}".format(vnet.alias, vnet.name), end="") + handler = self._get_vnet_handler(vnet.alias) + if handler: + print(" handler: {}".format(handler.__name__), end="") + print() + for iface_alias, iface_data in iface_map.items(): + vnets = iface_data.vnet_aliases + ifaces: List[VnetInterface] = iface_data.ifaces + if len(vnets) == 1 and len(ifaces) == 2: + print( + "# iface {}: {}::{} -> main::{}".format( + iface_alias, vnets[0], ifaces[0].name, ifaces[1].name + ) + ) + elif len(vnets) == 2 and len(ifaces) == 2: + print( + "# iface {}: {}::{} -> {}::{}".format( + iface_alias, vnets[0], ifaces[0].name, vnets[1], ifaces[1].name + ) + ) + else: + print( + "# iface {}: ifaces: {} vnets: {}".format( + iface_alias, vnets, [i.name for i in ifaces] + ) + ) + print() + return ObjectsMap(iface_map, vnet_map, topo) + + def setup_method(self, _method): + """Sets up all the required topology and handlers for the given test""" + super().setup_method(_method) + self._require_default_modules() + + # TestIP6Output.test_output6_pktinfo[ipandif] + topology_id = get_topology_id(self.test_id) + topology = self.TOPOLOGY + # First, setup kernel objects - interfaces & vnets + obj_map = self.setup_topology(topology, topology_id) + main_vnet = None # one without subprocess handler + for vnet_alias, vnet in obj_map.vnet_map.items(): + if self._get_vnet_handler(vnet_alias): + # Need subprocess to run + parent_pipe, child_pipe = Pipe() + p = Process( + target=self._setup_vnet, + args=( + vnet, + obj_map, + child_pipe, + ), + ) + vnet.set_pipe(parent_pipe) + vnet.set_subprocess(p) + p.start() + else: + if main_vnet is not None: + raise Exception("there can be only 1 VNET w/o handler") + main_vnet = vnet + # Main vnet needs to be the last, so all the other subprocesses + # are started & their pipe handles collected + self.vnet = main_vnet + self._setup_vnet(main_vnet, obj_map, None) + # Save state for the main handler + self.iface_map = obj_map.iface_map + self.vnet_map = obj_map.vnet_map + self.drop_privileges() + + def cleanup(self, test_id: str): + # pytest test id: file::class::test_name + topology_id = get_topology_id(self.test_id) + + print("============= vnet cleanup =============") + print("# topology_id: '{}'".format(topology_id)) + VnetFactory(topology_id).cleanup() + IfaceFactory().cleanup() + + def wait_object(self, pipe, timeout=5): + if pipe.poll(timeout): + return pipe.recv() + raise TimeoutError + + def wait_objects_any(self, pipe_list, timeout=5): + objects = connection.wait(pipe_list, timeout) + if objects: + return objects[0].recv() + raise TimeoutError + + def send_object(self, pipe, obj): + pipe.send(obj) + + def wait(self): + while True: + time.sleep(1) + + @property + def curvnet(self): + pass + + +class SingleVnetTestTemplate(VnetTestTemplate): + IPV6_PREFIXES: List[str] = [] + IPV4_PREFIXES: List[str] = [] + IFTYPE = "epair" + + def _setup_default_topology(self): + topology = copy.deepcopy( + { + "vnet1": {"ifaces": ["if1"]}, + "if1": {"type": self.IFTYPE, "prefixes4": [], "prefixes6": []}, + } + ) + for prefix in self.IPV6_PREFIXES: + topology["if1"]["prefixes6"].append((prefix,)) + for prefix in self.IPV4_PREFIXES: + topology["if1"]["prefixes4"].append((prefix,)) + return topology + + def setup_method(self, method): + if not getattr(self, "TOPOLOGY", None): + self.TOPOLOGY = self._setup_default_topology() + else: + names = self.TOPOLOGY.keys() + assert len([n for n in names if n.startswith("vnet")]) == 1 + super().setup_method(method) diff --git a/tests/atf_python/sys/netlink/Makefile b/tests/atf_python/sys/netlink/Makefile new file mode 100644 index 000000000000..6a40a93f3ae9 --- /dev/null +++ b/tests/atf_python/sys/netlink/Makefile @@ -0,0 +1,13 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +PACKAGE=tests +FILES= __init__.py attrs.py base_headers.py message.py netlink.py \ + netlink_generic.py netlink_route.py utils.py + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python/sys/netlink + +.include <bsd.prog.mk> + diff --git a/tests/atf_python/sys/netlink/__init__.py b/tests/atf_python/sys/netlink/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 --- /dev/null +++ b/tests/atf_python/sys/netlink/__init__.py diff --git a/tests/atf_python/sys/netlink/attrs.py b/tests/atf_python/sys/netlink/attrs.py new file mode 100644 index 000000000000..36dd8191df1c --- /dev/null +++ b/tests/atf_python/sys/netlink/attrs.py @@ -0,0 +1,335 @@ +import socket +import struct +from enum import Enum + +from atf_python.sys.netlink.utils import align4 +from atf_python.sys.netlink.utils import enum_or_int + + +class NlAttr(object): + HDR_LEN = 4 # sizeof(struct nlattr) + + def __init__(self, nla_type, data): + if isinstance(nla_type, Enum): + self._nla_type = nla_type.value + self._enum = nla_type + else: + self._nla_type = nla_type + self._enum = None + self.nla_list = [] + self._data = data + + @property + def nla_type(self): + return self._nla_type & 0x3FFF + + @property + def nla_len(self): + return len(self._data) + 4 + + def add_nla(self, nla): + self.nla_list.append(nla) + + def print_attr(self, prepend=""): + if self._enum is not None: + type_str = self._enum.name + else: + type_str = "nla#{}".format(self.nla_type) + print( + "{}len={} type={}({}){}".format( + prepend, self.nla_len, type_str, self.nla_type, self._print_attr_value() + ) + ) + + @staticmethod + def _validate(data): + if len(data) < 4: + raise ValueError("attribute too short") + nla_len, nla_type = struct.unpack("@HH", data[:4]) + if nla_len > len(data): + raise ValueError("attribute length too big") + if nla_len < 4: + raise ValueError("attribute length too short") + + @classmethod + def _parse(cls, data): + nla_len, nla_type = struct.unpack("@HH", data[:4]) + return cls(nla_type, data[4:]) + + @classmethod + def from_bytes(cls, data, attr_type_enum=None): + cls._validate(data) + attr = cls._parse(data) + attr._enum = attr_type_enum + return attr + + def _to_bytes(self, data: bytes): + ret = data + if align4(len(ret)) != len(ret): + ret = data + bytes(align4(len(ret)) - len(ret)) + return struct.pack("@HH", len(data) + 4, self._nla_type) + ret + + def __bytes__(self): + return self._to_bytes(self._data) + + def _print_attr_value(self): + return " " + " ".join(["x{:02X}".format(b) for b in self._data]) + + +class NlAttrNested(NlAttr): + def __init__(self, nla_type, val): + super().__init__(nla_type, b"") + self.nla_list = val + + def get_nla(self, nla_type): + nla_type_raw = enum_or_int(nla_type) + for nla in self.nla_list: + if nla.nla_type == nla_type_raw: + return nla + return None + + @property + def nla_len(self): + return align4(len(b"".join([bytes(nla) for nla in self.nla_list]))) + 4 + + def print_attr(self, prepend=""): + if self._enum is not None: + type_str = self._enum.name + else: + type_str = "nla#{}".format(self.nla_type) + print( + "{}len={} type={}({}) {{".format( + prepend, self.nla_len, type_str, self.nla_type + ) + ) + for nla in self.nla_list: + nla.print_attr(prepend + " ") + print("{}}}".format(prepend)) + + def __bytes__(self): + return self._to_bytes(b"".join([bytes(nla) for nla in self.nla_list])) + + +class NlAttrU32(NlAttr): + def __init__(self, nla_type, val): + self.u32 = enum_or_int(val) + super().__init__(nla_type, b"") + + @property + def nla_len(self): + return 8 + + def _print_attr_value(self): + return " val={}".format(self.u32) + + @staticmethod + def _validate(data): + assert len(data) == 8 + nla_len, nla_type = struct.unpack("@HH", data[:4]) + assert nla_len == 8 + + @classmethod + def _parse(cls, data): + nla_len, nla_type, val = struct.unpack("@HHI", data) + return cls(nla_type, val) + + def __bytes__(self): + return self._to_bytes(struct.pack("@I", self.u32)) + + +class NlAttrS32(NlAttr): + def __init__(self, nla_type, val): + self.s32 = enum_or_int(val) + super().__init__(nla_type, b"") + + @property + def nla_len(self): + return 8 + + def _print_attr_value(self): + return " val={}".format(self.s32) + + @staticmethod + def _validate(data): + assert len(data) == 8 + nla_len, nla_type = struct.unpack("@HH", data[:4]) + assert nla_len == 8 + + @classmethod + def _parse(cls, data): + nla_len, nla_type, val = struct.unpack("@HHi", data) + return cls(nla_type, val) + + def __bytes__(self): + return self._to_bytes(struct.pack("@i", self.s32)) + + +class NlAttrU16(NlAttr): + def __init__(self, nla_type, val): + self.u16 = enum_or_int(val) + super().__init__(nla_type, b"") + + @property + def nla_len(self): + return 6 + + def _print_attr_value(self): + return " val={}".format(self.u16) + + @staticmethod + def _validate(data): + assert len(data) == 6 + nla_len, nla_type = struct.unpack("@HH", data[:4]) + assert nla_len == 6 + + @classmethod + def _parse(cls, data): + nla_len, nla_type, val = struct.unpack("@HHH", data) + return cls(nla_type, val) + + def __bytes__(self): + return self._to_bytes(struct.pack("@H", self.u16)) + + +class NlAttrU8(NlAttr): + def __init__(self, nla_type, val): + self.u8 = enum_or_int(val) + super().__init__(nla_type, b"") + + @property + def nla_len(self): + return 5 + + def _print_attr_value(self): + return " val={}".format(self.u8) + + @staticmethod + def _validate(data): + assert len(data) == 5 + nla_len, nla_type = struct.unpack("@HH", data[:4]) + assert nla_len == 5 + + @classmethod + def _parse(cls, data): + nla_len, nla_type, val = struct.unpack("@HHB", data) + return cls(nla_type, val) + + def __bytes__(self): + return self._to_bytes(struct.pack("@B", self.u8)) + + +class NlAttrIp(NlAttr): + def __init__(self, nla_type, addr: str): + super().__init__(nla_type, b"") + self.addr = addr + if ":" in self.addr: + self.family = socket.AF_INET6 + else: + self.family = socket.AF_INET + + @staticmethod + def _validate(data): + nla_len, nla_type = struct.unpack("@HH", data[:4]) + data_len = nla_len - 4 + if data_len != 4 and data_len != 16: + raise ValueError( + "Error validating attr {}: nla_len is not valid".format( # noqa: E501 + nla_type + ) + ) + + @property + def nla_len(self): + if self.family == socket.AF_INET6: + return 20 + else: + return 8 + return align4(len(self._data)) + 4 + + @classmethod + def _parse(cls, data): + nla_len, nla_type = struct.unpack("@HH", data[:4]) + data_len = len(data) - 4 + if data_len == 4: + addr = socket.inet_ntop(socket.AF_INET, data[4:8]) + else: + addr = socket.inet_ntop(socket.AF_INET6, data[4:20]) + return cls(nla_type, addr) + + def __bytes__(self): + return self._to_bytes(socket.inet_pton(self.family, self.addr)) + + def _print_attr_value(self): + return " addr={}".format(self.addr) + + +class NlAttrIp4(NlAttrIp): + def __init__(self, nla_type, addr: str): + super().__init__(nla_type, addr) + assert self.family == socket.AF_INET + + +class NlAttrIp6(NlAttrIp): + def __init__(self, nla_type, addr: str): + super().__init__(nla_type, addr) + assert self.family == socket.AF_INET6 + + +class NlAttrStr(NlAttr): + def __init__(self, nla_type, text): + super().__init__(nla_type, b"") + self.text = text + + @staticmethod + def _validate(data): + NlAttr._validate(data) + try: + data[4:].decode("utf-8") + except Exception as e: + raise ValueError("wrong utf-8 string: {}".format(e)) + + @property + def nla_len(self): + return len(self.text) + 5 + + @classmethod + def _parse(cls, data): + text = data[4:-1].decode("utf-8") + nla_len, nla_type = struct.unpack("@HH", data[:4]) + return cls(nla_type, text) + + def __bytes__(self): + return self._to_bytes(bytes(self.text, encoding="utf-8") + bytes(1)) + + def _print_attr_value(self): + return ' val="{}"'.format(self.text) + + +class NlAttrStrn(NlAttr): + def __init__(self, nla_type, text): + super().__init__(nla_type, b"") + self.text = text + + @staticmethod + def _validate(data): + NlAttr._validate(data) + try: + data[4:].decode("utf-8") + except Exception as e: + raise ValueError("wrong utf-8 string: {}".format(e)) + + @property + def nla_len(self): + return len(self.text) + 4 + + @classmethod + def _parse(cls, data): + text = data[4:].decode("utf-8") + nla_len, nla_type = struct.unpack("@HH", data[:4]) + return cls(nla_type, text) + + def __bytes__(self): + return self._to_bytes(bytes(self.text, encoding="utf-8")) + + def _print_attr_value(self): + return ' val="{}"'.format(self.text) diff --git a/tests/atf_python/sys/netlink/base_headers.py b/tests/atf_python/sys/netlink/base_headers.py new file mode 100644 index 000000000000..71771a249b3d --- /dev/null +++ b/tests/atf_python/sys/netlink/base_headers.py @@ -0,0 +1,72 @@ +from ctypes import c_ubyte +from ctypes import c_uint +from ctypes import c_ushort +from ctypes import Structure +from enum import Enum + + +class Nlmsghdr(Structure): + _fields_ = [ + ("nlmsg_len", c_uint), + ("nlmsg_type", c_ushort), + ("nlmsg_flags", c_ushort), + ("nlmsg_seq", c_uint), + ("nlmsg_pid", c_uint), + ] + + +class Nlattr(Structure): + _fields_ = [ + ("nla_len", c_ushort), + ("nla_type", c_ushort), + ] + + +class NlMsgType(Enum): + NLMSG_NOOP = 1 + NLMSG_ERROR = 2 + NLMSG_DONE = 3 + NLMSG_OVERRUN = 4 + + +class NlmBaseFlags(Enum): + NLM_F_REQUEST = 0x01 + NLM_F_MULTI = 0x02 + NLM_F_ACK = 0x04 + NLM_F_ECHO = 0x08 + NLM_F_DUMP_INTR = 0x10 + NLM_F_DUMP_FILTERED = 0x20 + + +# XXX: in python3.8 it is possible to +# class NlmGetFlags(Enum, NlmBaseFlags): + + +class NlmGetFlags(Enum): + NLM_F_ROOT = 0x100 + NLM_F_MATCH = 0x200 + NLM_F_ATOMIC = 0x400 + + +class NlmNewFlags(Enum): + NLM_F_REPLACE = 0x100 + NLM_F_EXCL = 0x200 + NLM_F_CREATE = 0x400 + NLM_F_APPEND = 0x800 + + +class NlmDeleteFlags(Enum): + NLM_F_NONREC = 0x100 + + +class NlmAckFlags(Enum): + NLM_F_CAPPED = 0x100 + NLM_F_ACK_TLVS = 0x200 + + +class GenlMsgHdr(Structure): + _fields_ = [ + ("cmd", c_ubyte), + ("version", c_ubyte), + ("reserved", c_ushort), + ] diff --git a/tests/atf_python/sys/netlink/message.py b/tests/atf_python/sys/netlink/message.py new file mode 100644 index 000000000000..98a1e3bb21c5 --- /dev/null +++ b/tests/atf_python/sys/netlink/message.py @@ -0,0 +1,286 @@ +#!/usr/local/bin/python3 +import struct +from ctypes import sizeof +from enum import Enum +from typing import List +from typing import NamedTuple + +from atf_python.sys.netlink.attrs import NlAttr +from atf_python.sys.netlink.attrs import NlAttrNested +from atf_python.sys.netlink.base_headers import NlmAckFlags +from atf_python.sys.netlink.base_headers import NlmNewFlags +from atf_python.sys.netlink.base_headers import NlmGetFlags +from atf_python.sys.netlink.base_headers import NlmDeleteFlags +from atf_python.sys.netlink.base_headers import NlmBaseFlags +from atf_python.sys.netlink.base_headers import Nlmsghdr +from atf_python.sys.netlink.base_headers import NlMsgType +from atf_python.sys.netlink.utils import align4 +from atf_python.sys.netlink.utils import enum_or_int +from atf_python.sys.netlink.utils import get_bitmask_str + + +class NlMsgCategory(Enum): + UNKNOWN = 0 + GET = 1 + NEW = 2 + DELETE = 3 + ACK = 4 + + +class NlMsgProps(NamedTuple): + msg: Enum + category: NlMsgCategory + + +class BaseNetlinkMessage(object): + def __init__(self, helper, nlmsg_type): + self.nlmsg_type = enum_or_int(nlmsg_type) + self.nla_list = [] + self._orig_data = None + self.helper = helper + self.nl_hdr = Nlmsghdr( + nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid + ) + self.base_hdr = None + + def set_request(self, need_ack=True): + self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST]) + if need_ack: + self.add_nlflags([NlmBaseFlags.NLM_F_ACK]) + + def add_nlflags(self, flags: List): + int_flags = 0 + for flag in flags: + int_flags |= enum_or_int(flag) + self.nl_hdr.nlmsg_flags |= int_flags + + def add_nla(self, nla): + self.nla_list.append(nla) + + def _get_nla(self, nla_list, nla_type): + nla_type_raw = enum_or_int(nla_type) + for nla in nla_list: + if nla.nla_type == nla_type_raw: + return nla + return None + + def get_nla(self, nla_type): + return self._get_nla(self.nla_list, nla_type) + + @staticmethod + def parse_nl_header(data: bytes): + if len(data) < sizeof(Nlmsghdr): + raise ValueError("length less than netlink message header") + return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr) + + def is_type(self, nlmsg_type): + nlmsg_type_raw = enum_or_int(nlmsg_type) + return nlmsg_type_raw == self.nl_hdr.nlmsg_type + + def is_reply(self, hdr): + return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value + + @property + def msg_name(self): + return "msg#{}".format(self._get_msg_type()) + + def _get_nl_category(self): + if self.is_reply(self.nl_hdr): + return NlMsgCategory.ACK + return NlMsgCategory.UNKNOWN + + def get_nlm_flags_str(self): + category = self._get_nl_category() + flags = self.nl_hdr.nlmsg_flags + + if category == NlMsgCategory.UNKNOWN: + return self.helper.get_bitmask_str(NlmBaseFlags, flags) + elif category == NlMsgCategory.GET: + flags_enum = NlmGetFlags + elif category == NlMsgCategory.NEW: + flags_enum = NlmNewFlags + elif category == NlMsgCategory.DELETE: + flags_enum = NlmDeleteFlags + elif category == NlMsgCategory.ACK: + flags_enum = NlmAckFlags + return get_bitmask_str([NlmBaseFlags, flags_enum], flags) + + def print_nl_header(self, prepend=""): + # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501 + hdr = self.nl_hdr + print( + "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( + prepend, + hdr.nlmsg_len, + self.msg_name, + self.get_nlm_flags_str(), + hdr.nlmsg_flags, + hdr.nlmsg_seq, + hdr.nlmsg_pid, + ) + ) + + @classmethod + def from_bytes(cls, helper, data): + try: + hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data) + self = cls(helper, hdr.nlmsg_type) + self._orig_data = data + self.nl_hdr = hdr + except ValueError as e: + print("Failed to parse nl header: {}".format(e)) + cls.print_as_bytes(data) + raise + return self + + def print_message(self): + self.print_nl_header() + + @staticmethod + def print_as_bytes(data: bytes, descr: str): + print("===vv {} (len:{:3d}) vv===".format(descr, len(data))) + off = 0 + step = 16 + while off < len(data): + for i in range(step): + if off + i < len(data): + print(" {:02X}".format(data[off + i]), end="") + print("") + off += step + print("--------------------") + + +class StdNetlinkMessage(BaseNetlinkMessage): + nl_attrs_map = {} + + @classmethod + def from_bytes(cls, helper, data): + try: + hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data) + self = cls(helper, hdr.nlmsg_type) + self._orig_data = data + self.nl_hdr = hdr + except ValueError as e: + print("Failed to parse nl header: {}".format(e)) + cls.print_as_bytes(data) + raise + + offset = align4(hdrlen) + try: + base_hdr, hdrlen = self.parse_base_header(data[offset:]) + self.base_hdr = base_hdr + offset += align4(hdrlen) + # XXX: CAP_ACK + except ValueError as e: + print("Failed to parse nl rt header: {}".format(e)) + cls.print_as_bytes(data) + raise + + orig_offset = offset + try: + nla_list, nla_len = self.parse_nla_list(data[offset:]) + offset += nla_len + if offset != len(data): + raise ValueError( + "{} bytes left at the end of the packet".format(len(data) - offset) + ) # noqa: E501 + self.nla_list = nla_list + except ValueError as e: + print( + "Failed to parse nla attributes at offset {}: {}".format(orig_offset, e) + ) # noqa: E501 + cls.print_as_bytes(data, "msg dump") + cls.print_as_bytes(data[orig_offset:], "failed block") + raise + return self + + def parse_child(self, data: bytes, attr_key, attr_map): + attrs, _ = self.parse_attrs(data, attr_map) + return NlAttrNested(attr_key, attrs) + + def parse_child_array(self, data: bytes, attr_key, attr_map): + ret = [] + off = 0 + while len(data) - off >= 4: + nla_len, raw_nla_type = struct.unpack("@HH", data[off : off + 4]) + if nla_len + off > len(data): + raise ValueError( + "attr length {} > than the remaining length {}".format( + nla_len, len(data) - off + ) + ) + nla_type = raw_nla_type & 0x3FFF + val = self.parse_child(data[off + 4 : off + nla_len], nla_type, attr_map) + ret.append(val) + off += align4(nla_len) + return NlAttrNested(attr_key, ret) + + def parse_attrs(self, data: bytes, attr_map): + ret = [] + off = 0 + while len(data) - off >= 4: + nla_len, raw_nla_type = struct.unpack("@HH", data[off : off + 4]) + if nla_len + off > len(data): + raise ValueError( + "attr length {} > than the remaining length {}".format( + nla_len, len(data) - off + ) + ) + nla_type = raw_nla_type & 0x3FFF + if nla_type in attr_map: + v = attr_map[nla_type] + val = v["ad"].cls.from_bytes(data[off : off + nla_len], v["ad"].val) + if "child" in v: + # nested + child_data = data[off + 4 : off + nla_len] + if v.get("is_array", False): + # Array of nested attributes + val = self.parse_child_array( + child_data, v["ad"].val, v["child"] + ) + else: + val = self.parse_child(child_data, v["ad"].val, v["child"]) + else: + # unknown attribute + val = NlAttr(raw_nla_type, data[off + 4 : off + nla_len]) + ret.append(val) + off += align4(nla_len) + return ret, off + + def parse_nla_list(self, data: bytes) -> List[NlAttr]: + return self.parse_attrs(data, self.nl_attrs_map) + + def __bytes__(self): + ret = bytes() + for nla in self.nla_list: + ret += bytes(nla) + ret = bytes(self.base_hdr) + ret + self.nl_hdr.nlmsg_len = len(ret) + sizeof(Nlmsghdr) + return bytes(self.nl_hdr) + ret + + def _get_msg_type(self): + return self.nl_hdr.nlmsg_type + + @property + def msg_props(self): + msg_type = self._get_msg_type() + for msg_props in self.messages: + if msg_props.msg.value == msg_type: + return msg_props + return None + + @property + def msg_name(self): + msg_props = self.msg_props + if msg_props is not None: + return msg_props.msg.name + return super().msg_name + + def print_base_header(self, hdr, prepend=""): + pass + + def print_message(self): + self.print_nl_header() + self.print_base_header(self.base_hdr, " ") + for nla in self.nla_list: + nla.print_attr(" ") diff --git a/tests/atf_python/sys/netlink/netlink.py b/tests/atf_python/sys/netlink/netlink.py new file mode 100644 index 000000000000..f8f886b09b24 --- /dev/null +++ b/tests/atf_python/sys/netlink/netlink.py @@ -0,0 +1,417 @@ +#!/usr/local/bin/python3 +import os +import socket +import sys +from ctypes import c_int +from ctypes import c_ubyte +from ctypes import c_uint +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from enum import auto +from enum import Enum + +from atf_python.sys.netlink.attrs import NlAttr +from atf_python.sys.netlink.attrs import NlAttrStr +from atf_python.sys.netlink.attrs import NlAttrU32 +from atf_python.sys.netlink.base_headers import GenlMsgHdr +from atf_python.sys.netlink.base_headers import NlmBaseFlags +from atf_python.sys.netlink.base_headers import Nlmsghdr +from atf_python.sys.netlink.base_headers import NlMsgType +from atf_python.sys.netlink.message import BaseNetlinkMessage +from atf_python.sys.netlink.message import NlMsgCategory +from atf_python.sys.netlink.message import NlMsgProps +from atf_python.sys.netlink.message import StdNetlinkMessage +from atf_python.sys.netlink.netlink_generic import GenlCtrlAttrType +from atf_python.sys.netlink.netlink_generic import GenlCtrlMsgType +from atf_python.sys.netlink.netlink_generic import handler_classes as genl_classes +from atf_python.sys.netlink.netlink_route import handler_classes as rt_classes +from atf_python.sys.netlink.utils import align4 +from atf_python.sys.netlink.utils import AttrDescr +from atf_python.sys.netlink.utils import build_propmap +from atf_python.sys.netlink.utils import enum_or_int +from atf_python.sys.netlink.utils import get_bitmask_map +from atf_python.sys.netlink.utils import NlConst +from atf_python.sys.netlink.utils import prepare_attrs_map + + +class SockaddrNl(Structure): + _fields_ = [ + ("nl_len", c_ubyte), + ("nl_family", c_ubyte), + ("nl_pad", c_ushort), + ("nl_pid", c_uint), + ("nl_groups", c_uint), + ] + + +class Nlmsgdone(Structure): + _fields_ = [ + ("error", c_int), + ] + + +class Nlmsgerr(Structure): + _fields_ = [ + ("error", c_int), + ("msg", Nlmsghdr), + ] + + +class NlErrattrType(Enum): + NLMSGERR_ATTR_UNUSED = 0 + NLMSGERR_ATTR_MSG = auto() + NLMSGERR_ATTR_OFFS = auto() + NLMSGERR_ATTR_COOKIE = auto() + NLMSGERR_ATTR_POLICY = auto() + + +class AddressFamilyLinux(Enum): + AF_INET = socket.AF_INET + AF_INET6 = socket.AF_INET6 + AF_NETLINK = 16 + + +class AddressFamilyBsd(Enum): + AF_INET = socket.AF_INET + AF_INET6 = socket.AF_INET6 + AF_NETLINK = 38 + + +class NlHelper: + def __init__(self): + self._pmap = {} + self._af_cls = self.get_af_cls() + self._seq_counter = 1 + self.pid = os.getpid() + + def get_seq(self): + ret = self._seq_counter + self._seq_counter += 1 + return ret + + def get_af_cls(self): + if sys.platform.startswith("freebsd"): + cls = AddressFamilyBsd + else: + cls = AddressFamilyLinux + return cls + + def get_propmap(self, cls): + if cls not in self._pmap: + self._pmap[cls] = build_propmap(cls) + return self._pmap[cls] + + def get_name_propmap(self, cls): + ret = {} + for prop in dir(cls): + if not prop.startswith("_"): + ret[prop] = getattr(cls, prop).value + return ret + + def get_attr_byval(self, cls, attr_val): + propmap = self.get_propmap(cls) + return propmap.get(attr_val) + + def get_af_name(self, family): + v = self.get_attr_byval(self._af_cls, family) + if v is not None: + return v + return "af#{}".format(family) + + def get_af_value(self, family_str: str) -> int: + propmap = self.get_name_propmap(self._af_cls) + return propmap.get(family_str) + + def get_bitmask_str(self, cls, val): + bmap = get_bitmask_map(self.get_propmap(cls), val) + return ",".join([v for k, v in bmap.items()]) + + @staticmethod + def get_bitmask_str_uncached(cls, val): + pmap = NlHelper.build_propmap(cls) + bmap = NlHelper.get_bitmask_map(pmap, val) + return ",".join([v for k, v in bmap.items()]) + + +nldone_attrs = prepare_attrs_map([]) + +nlerr_attrs = prepare_attrs_map( + [ + AttrDescr(NlErrattrType.NLMSGERR_ATTR_MSG, NlAttrStr), + AttrDescr(NlErrattrType.NLMSGERR_ATTR_OFFS, NlAttrU32), + AttrDescr(NlErrattrType.NLMSGERR_ATTR_COOKIE, NlAttr), + ] +) + + +class NetlinkDoneMessage(StdNetlinkMessage): + messages = [NlMsgProps(NlMsgType.NLMSG_DONE, NlMsgCategory.ACK)] + nl_attrs_map = nldone_attrs + + @property + def error_code(self): + return self.base_hdr.error + + def parse_base_header(self, data): + if len(data) < sizeof(Nlmsgdone): + raise ValueError("length less than nlmsgdone header") + done_hdr = Nlmsgdone.from_buffer_copy(data) + sz = sizeof(Nlmsgdone) + return (done_hdr, sz) + + def print_base_header(self, hdr, prepend=""): + print("{}error={}".format(prepend, hdr.error)) + + +class NetlinkErrorMessage(StdNetlinkMessage): + messages = [NlMsgProps(NlMsgType.NLMSG_ERROR, NlMsgCategory.ACK)] + nl_attrs_map = nlerr_attrs + + @property + def error_code(self): + return self.base_hdr.error + + @property + def error_str(self): + nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_MSG) + if nla: + return nla.text + return None + + @property + def error_offset(self): + nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_OFFS) + if nla: + return nla.u32 + return None + + @property + def cookie(self): + return self.get_nla(NlErrattrType.NLMSGERR_ATTR_COOKIE) + + def parse_base_header(self, data): + if len(data) < sizeof(Nlmsgerr): + raise ValueError("length less than nlmsgerr header") + err_hdr = Nlmsgerr.from_buffer_copy(data) + sz = sizeof(Nlmsgerr) + if (self.nl_hdr.nlmsg_flags & 0x100) == 0: + sz += align4(err_hdr.msg.nlmsg_len - sizeof(Nlmsghdr)) + return (err_hdr, sz) + + def print_base_header(self, errhdr, prepend=""): + print("{}error={}, ".format(prepend, errhdr.error), end="") + hdr = errhdr.msg + print( + "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( + prepend, + hdr.nlmsg_len, + "msg#{}".format(hdr.nlmsg_type), + self.helper.get_bitmask_str(NlmBaseFlags, hdr.nlmsg_flags), + hdr.nlmsg_flags, + hdr.nlmsg_seq, + hdr.nlmsg_pid, + ) + ) + + +core_classes = { + "netlink_core": [ + NetlinkDoneMessage, + NetlinkErrorMessage, + ], +} + + +class Nlsock: + HANDLER_CLASSES = [core_classes, rt_classes, genl_classes] + + def __init__(self, family, helper): + self.helper = helper + self.sock_fd = self._setup_netlink(family) + self._sock_family = family + self._data = bytes() + self.msgmap = self.build_msgmap() + self._family_map = { + NlConst.GENL_ID_CTRL: "nlctrl", + } + + def build_msgmap(self): + handler_classes = {} + for d in self.HANDLER_CLASSES: + handler_classes.update(d) + xmap = {} + # 'family_name': [class.messages[MsgProps.msg], ] + for family_id, family_classes in handler_classes.items(): + xmap[family_id] = {} + for cls in family_classes: + for msg_props in cls.messages: + xmap[family_id][enum_or_int(msg_props.msg)] = cls + return xmap + + def _setup_netlink(self, netlink_family) -> int: + family = self.helper.get_af_value("AF_NETLINK") + s = socket.socket(family, socket.SOCK_RAW, netlink_family) + s.setsockopt(270, 10, 1) # NETLINK_CAP_ACK + s.setsockopt(270, 11, 1) # NETLINK_EXT_ACK + return s + + def set_groups(self, mask: int): + self.sock_fd.setsockopt(socket.SOL_SOCKET, 1, mask) + # snl = SockaddrNl(nl_len = sizeof(SockaddrNl), nl_family=38, + # nl_pid=self.pid, nl_groups=mask) + # xbuffer = create_string_buffer(sizeof(SockaddrNl)) + # memmove(xbuffer, addressof(snl), sizeof(SockaddrNl)) + # k = struct.pack("@BBHII", 12, 38, 0, self.pid, mask) + # self.sock_fd.bind(k) + + def join_group(self, group_id: int): + self.sock_fd.setsockopt(270, 1, group_id) + + def write_message(self, msg, verbose=True): + if verbose: + print("vvvvvvvv OUT vvvvvvvv") + msg.print_message() + msg_bytes = bytes(msg) + try: + ret = os.write(self.sock_fd.fileno(), msg_bytes) + assert ret == len(msg_bytes) + except Exception as e: + print("write({}) -> {}".format(len(msg_bytes), e)) + + def parse_message(self, data: bytes): + if len(data) < sizeof(Nlmsghdr): + raise Exception("Short read from nl: {} bytes".format(len(data))) + hdr = Nlmsghdr.from_buffer_copy(data) + if hdr.nlmsg_type < 16: + family_name = "netlink_core" + nlmsg_type = hdr.nlmsg_type + elif self._sock_family == NlConst.NETLINK_ROUTE: + family_name = "netlink_route" + nlmsg_type = hdr.nlmsg_type + else: + # Genetlink + if len(data) < sizeof(Nlmsghdr) + sizeof(GenlMsgHdr): + raise Exception("Short read from genl: {} bytes".format(len(data))) + family_name = self._family_map.get(hdr.nlmsg_type, "") + ghdr = GenlMsgHdr.from_buffer_copy(data[sizeof(Nlmsghdr):]) + nlmsg_type = ghdr.cmd + cls = self.msgmap.get(family_name, {}).get(nlmsg_type) + if not cls: + cls = BaseNetlinkMessage + return cls.from_bytes(self.helper, data) + + def get_genl_family_id(self, family_name): + hdr = Nlmsghdr( + nlmsg_type=NlConst.GENL_ID_CTRL, + nlmsg_flags=NlmBaseFlags.NLM_F_REQUEST.value, + nlmsg_seq=self.helper.get_seq(), + ) + ghdr = GenlMsgHdr(cmd=GenlCtrlMsgType.CTRL_CMD_GETFAMILY.value) + nla = NlAttrStr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, family_name) + hdr.nlmsg_len = sizeof(Nlmsghdr) + sizeof(GenlMsgHdr) + len(bytes(nla)) + + msg_bytes = bytes(hdr) + bytes(ghdr) + bytes(nla) + self.write_data(msg_bytes) + while True: + rx_msg = self.read_message() + if hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: + if rx_msg.is_type(NlMsgType.NLMSG_ERROR): + if rx_msg.error_code != 0: + raise ValueError("unable to get family {}".format(family_name)) + else: + family_id = rx_msg.get_nla(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID).u16 + self._family_map[family_id] = family_name + return family_id + raise ValueError("unable to get family {}".format(family_name)) + + def write_data(self, data: bytes): + self.sock_fd.send(data) + + def read_data(self): + while True: + data = self.sock_fd.recv(65535) + self._data += data + if len(self._data) >= sizeof(Nlmsghdr): + break + + def read_message(self) -> bytes: + if len(self._data) < sizeof(Nlmsghdr): + self.read_data() + hdr = Nlmsghdr.from_buffer_copy(self._data) + while hdr.nlmsg_len > len(self._data): + self.read_data() + raw_msg = self._data[: hdr.nlmsg_len] + self._data = self._data[hdr.nlmsg_len:] + return self.parse_message(raw_msg) + + def get_reply(self, tx_msg): + self.write_message(tx_msg) + while True: + rx_msg = self.read_message() + if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: + return rx_msg + + +class NetlinkMultipartIterator(object): + def __init__(self, obj, seq_number: int, msg_type): + self._obj = obj + self._seq = seq_number + self._msg_type = msg_type + + def __iter__(self): + return self + + def __next__(self): + msg = self._obj.read_message() + if self._seq != msg.nl_hdr.nlmsg_seq: + raise ValueError("bad sequence number") + if msg.is_type(NlMsgType.NLMSG_ERROR): + raise ValueError( + "error while handling multipart msg: {}".format(msg.error_code) + ) + elif msg.is_type(NlMsgType.NLMSG_DONE): + if msg.error_code == 0: + raise StopIteration + raise ValueError( + "error listing some parts of the multipart msg: {}".format( + msg.error_code + ) + ) + elif not msg.is_type(self._msg_type): + raise ValueError("bad message type: {}".format(msg)) + return msg + + +class NetlinkTestTemplate(object): + REQUIRED_MODULES = ["netlink"] + + def setup_netlink(self, netlink_family: NlConst): + self.helper = NlHelper() + self.nlsock = Nlsock(netlink_family, self.helper) + + def write_message(self, msg, silent=False): + if not silent: + print("") + print("============= >> TX MESSAGE =============") + msg.print_message() + msg.print_as_bytes(bytes(msg), "-- DATA --") + self.nlsock.write_data(bytes(msg)) + + def read_message(self, silent=False): + msg = self.nlsock.read_message() + if not silent: + print("") + print("============= << RX MESSAGE =============") + msg.print_message() + return msg + + def get_reply(self, tx_msg): + self.write_message(tx_msg) + while True: + rx_msg = self.read_message() + if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: + return rx_msg + + def read_msg_list(self, seq, msg_type): + return list(NetlinkMultipartIterator(self, seq, msg_type)) diff --git a/tests/atf_python/sys/netlink/netlink_generic.py b/tests/atf_python/sys/netlink/netlink_generic.py new file mode 100644 index 000000000000..80c6eea72a93 --- /dev/null +++ b/tests/atf_python/sys/netlink/netlink_generic.py @@ -0,0 +1,312 @@ +#!/usr/local/bin/python3 +import struct +from ctypes import c_int64 +from ctypes import c_long +from ctypes import sizeof +from ctypes import Structure +from enum import Enum + +from atf_python.sys.netlink.attrs import NlAttr +from atf_python.sys.netlink.attrs import NlAttrIp4 +from atf_python.sys.netlink.attrs import NlAttrIp6 +from atf_python.sys.netlink.attrs import NlAttrNested +from atf_python.sys.netlink.attrs import NlAttrS32 +from atf_python.sys.netlink.attrs import NlAttrStr +from atf_python.sys.netlink.attrs import NlAttrU16 +from atf_python.sys.netlink.attrs import NlAttrU32 +from atf_python.sys.netlink.attrs import NlAttrU8 +from atf_python.sys.netlink.base_headers import GenlMsgHdr +from atf_python.sys.netlink.message import NlMsgCategory +from atf_python.sys.netlink.message import NlMsgProps +from atf_python.sys.netlink.message import StdNetlinkMessage +from atf_python.sys.netlink.utils import AttrDescr +from atf_python.sys.netlink.utils import enum_or_int +from atf_python.sys.netlink.utils import prepare_attrs_map + + +class NetlinkGenlMessage(StdNetlinkMessage): + messages = [] + nl_attrs_map = {} + family_name = None + + def __init__(self, helper, family_id, cmd=0): + super().__init__(helper, family_id) + self.base_hdr = GenlMsgHdr(cmd=enum_or_int(cmd)) + + def parse_base_header(self, data): + if len(data) < sizeof(GenlMsgHdr): + raise ValueError("length less than GenlMsgHdr header") + ghdr = GenlMsgHdr.from_buffer_copy(data) + return (ghdr, sizeof(GenlMsgHdr)) + + def _get_msg_type(self): + return self.base_hdr.cmd + + def print_nl_header(self, prepend=""): + # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501 + hdr = self.nl_hdr + print( + "{}len={}, family={}, flags={}(0x{:X}), seq={}, pid={}".format( + prepend, + hdr.nlmsg_len, + self.family_name, + self.get_nlm_flags_str(), + hdr.nlmsg_flags, + hdr.nlmsg_seq, + hdr.nlmsg_pid, + ) + ) + + def print_base_header(self, hdr, prepend=""): + print( + "{}cmd={} version={} reserved={}".format( + prepend, self.msg_name, hdr.version, hdr.reserved + ) + ) + + +GenlCtrlFamilyName = "nlctrl" + + +class GenlCtrlMsgType(Enum): + CTRL_CMD_UNSPEC = 0 + CTRL_CMD_NEWFAMILY = 1 + CTRL_CMD_DELFAMILY = 2 + CTRL_CMD_GETFAMILY = 3 + CTRL_CMD_NEWOPS = 4 + CTRL_CMD_DELOPS = 5 + CTRL_CMD_GETOPS = 6 + CTRL_CMD_NEWMCAST_GRP = 7 + CTRL_CMD_DELMCAST_GRP = 8 + CTRL_CMD_GETMCAST_GRP = 9 + CTRL_CMD_GETPOLICY = 10 + + +class GenlCtrlAttrType(Enum): + CTRL_ATTR_FAMILY_ID = 1 + CTRL_ATTR_FAMILY_NAME = 2 + CTRL_ATTR_VERSION = 3 + CTRL_ATTR_HDRSIZE = 4 + CTRL_ATTR_MAXATTR = 5 + CTRL_ATTR_OPS = 6 + CTRL_ATTR_MCAST_GROUPS = 7 + CTRL_ATTR_POLICY = 8 + CTRL_ATTR_OP_POLICY = 9 + CTRL_ATTR_OP = 10 + + +class GenlCtrlAttrOpType(Enum): + CTRL_ATTR_OP_ID = 1 + CTRL_ATTR_OP_FLAGS = 2 + + +class GenlCtrlAttrMcastGroupsType(Enum): + CTRL_ATTR_MCAST_GRP_NAME = 1 + CTRL_ATTR_MCAST_GRP_ID = 2 + + +genl_ctrl_attrs = prepare_attrs_map( + [ + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID, NlAttrU16), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, NlAttrStr), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_VERSION, NlAttrU32), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_HDRSIZE, NlAttrU32), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_MAXATTR, NlAttrU32), + AttrDescr( + GenlCtrlAttrType.CTRL_ATTR_OPS, + NlAttrNested, + [ + AttrDescr(GenlCtrlAttrOpType.CTRL_ATTR_OP_ID, NlAttrU32), + AttrDescr(GenlCtrlAttrOpType.CTRL_ATTR_OP_FLAGS, NlAttrU32), + ], + True, + ), + AttrDescr( + GenlCtrlAttrType.CTRL_ATTR_MCAST_GROUPS, + NlAttrNested, + [ + AttrDescr( + GenlCtrlAttrMcastGroupsType.CTRL_ATTR_MCAST_GRP_NAME, NlAttrStr + ), + AttrDescr( + GenlCtrlAttrMcastGroupsType.CTRL_ATTR_MCAST_GRP_ID, NlAttrU32 + ), + ], + True, + ), + ] +) + + +class NetlinkGenlCtrlMessage(NetlinkGenlMessage): + messages = [ + NlMsgProps(GenlCtrlMsgType.CTRL_CMD_NEWFAMILY, NlMsgCategory.NEW), + NlMsgProps(GenlCtrlMsgType.CTRL_CMD_GETFAMILY, NlMsgCategory.GET), + NlMsgProps(GenlCtrlMsgType.CTRL_CMD_DELFAMILY, NlMsgCategory.DELETE), + ] + nl_attrs_map = genl_ctrl_attrs + family_name = GenlCtrlFamilyName + + +CarpFamilyName = "carp" + + +class CarpMsgType(Enum): + CARP_NL_CMD_UNSPEC = 0 + CARP_NL_CMD_GET = 1 + CARP_NL_CMD_SET = 2 + + +class CarpAttrType(Enum): + CARP_NL_UNSPEC = 0 + CARP_NL_VHID = 1 + CARP_NL_STATE = 2 + CARP_NL_ADVBASE = 3 + CARP_NL_ADVSKEW = 4 + CARP_NL_KEY = 5 + CARP_NL_IFINDEX = 6 + CARP_NL_ADDR = 7 + CARP_NL_ADDR6 = 8 + CARP_NL_IFNAME = 9 + + +carp_gen_attrs = prepare_attrs_map( + [ + AttrDescr(CarpAttrType.CARP_NL_VHID, NlAttrU32), + AttrDescr(CarpAttrType.CARP_NL_STATE, NlAttrU32), + AttrDescr(CarpAttrType.CARP_NL_ADVBASE, NlAttrS32), + AttrDescr(CarpAttrType.CARP_NL_ADVSKEW, NlAttrS32), + AttrDescr(CarpAttrType.CARP_NL_KEY, NlAttr), + AttrDescr(CarpAttrType.CARP_NL_IFINDEX, NlAttrU32), + AttrDescr(CarpAttrType.CARP_NL_ADDR, NlAttrIp4), + AttrDescr(CarpAttrType.CARP_NL_ADDR6, NlAttrIp6), + AttrDescr(CarpAttrType.CARP_NL_IFNAME, NlAttrStr), + ] +) + + +class CarpGenMessage(NetlinkGenlMessage): + messages = [ + NlMsgProps(CarpMsgType.CARP_NL_CMD_GET, NlMsgCategory.GET), + NlMsgProps(CarpMsgType.CARP_NL_CMD_SET, NlMsgCategory.NEW), + ] + nl_attrs_map = carp_gen_attrs + family_name = CarpFamilyName + + +KtestFamilyName = "ktest" + + +class KtestMsgType(Enum): + KTEST_CMD_UNSPEC = 0 + KTEST_CMD_LIST = 1 + KTEST_CMD_RUN = 2 + KTEST_CMD_NEWTEST = 3 + KTEST_CMD_NEWMESSAGE = 4 + + +class KtestAttrType(Enum): + KTEST_ATTR_MOD_NAME = 1 + KTEST_ATTR_TEST_NAME = 2 + KTEST_ATTR_TEST_DESCR = 3 + KTEST_ATTR_TEST_META = 4 + + +class KtestLogMsgType(Enum): + KTEST_MSG_START = 1 + KTEST_MSG_END = 2 + KTEST_MSG_LOG = 3 + KTEST_MSG_FAIL = 4 + + +class KtestMsgAttrType(Enum): + KTEST_MSG_ATTR_TS = 1 + KTEST_MSG_ATTR_FUNC = 2 + KTEST_MSG_ATTR_FILE = 3 + KTEST_MSG_ATTR_LINE = 4 + KTEST_MSG_ATTR_TEXT = 5 + KTEST_MSG_ATTR_LEVEL = 6 + KTEST_MSG_ATTR_META = 7 + + +class timespec(Structure): + _fields_ = [ + ("tv_sec", c_int64), + ("tv_nsec", c_long), + ] + + +class NlAttrTS(NlAttr): + DATA_LEN = sizeof(timespec) + + def __init__(self, nla_type, val): + self.ts = val + super().__init__(nla_type, b"") + + @property + def nla_len(self): + return NlAttr.HDR_LEN + self.DATA_LEN + + def _print_attr_value(self): + return " tv_sec={} tv_nsec={}".format(self.ts.tv_sec, self.ts.tv_nsec) + + @staticmethod + def _validate(data): + assert len(data) == NlAttr.HDR_LEN + NlAttrTS.DATA_LEN + nla_len, nla_type = struct.unpack("@HH", data[: NlAttr.HDR_LEN]) + assert nla_len == NlAttr.HDR_LEN + NlAttrTS.DATA_LEN + + @classmethod + def _parse(cls, data): + nla_len, nla_type = struct.unpack("@HH", data[: NlAttr.HDR_LEN]) + val = timespec.from_buffer_copy(data[NlAttr.HDR_LEN :]) + return cls(nla_type, val) + + def __bytes__(self): + return self._to_bytes(bytes(self.ts)) + + +ktest_info_attrs = prepare_attrs_map( + [ + AttrDescr(KtestAttrType.KTEST_ATTR_MOD_NAME, NlAttrStr), + AttrDescr(KtestAttrType.KTEST_ATTR_TEST_NAME, NlAttrStr), + AttrDescr(KtestAttrType.KTEST_ATTR_TEST_DESCR, NlAttrStr), + ] +) + + +ktest_msg_attrs = prepare_attrs_map( + [ + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_FUNC, NlAttrStr), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_FILE, NlAttrStr), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_LINE, NlAttrU32), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_TEXT, NlAttrStr), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_LEVEL, NlAttrU8), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_TS, NlAttrTS), + ] +) + + +class KtestInfoMessage(NetlinkGenlMessage): + messages = [ + NlMsgProps(KtestMsgType.KTEST_CMD_LIST, NlMsgCategory.GET), + NlMsgProps(KtestMsgType.KTEST_CMD_RUN, NlMsgCategory.NEW), + NlMsgProps(KtestMsgType.KTEST_CMD_NEWTEST, NlMsgCategory.NEW), + ] + nl_attrs_map = ktest_info_attrs + family_name = KtestFamilyName + + +class KtestMsgMessage(NetlinkGenlMessage): + messages = [ + NlMsgProps(KtestMsgType.KTEST_CMD_NEWMESSAGE, NlMsgCategory.NEW), + ] + nl_attrs_map = ktest_msg_attrs + family_name = KtestFamilyName + + +handler_classes = { + CarpFamilyName: [CarpGenMessage], + GenlCtrlFamilyName: [NetlinkGenlCtrlMessage], + KtestFamilyName: [KtestInfoMessage, KtestMsgMessage], +} diff --git a/tests/atf_python/sys/netlink/netlink_route.py b/tests/atf_python/sys/netlink/netlink_route.py new file mode 100644 index 000000000000..2cfeb57da13f --- /dev/null +++ b/tests/atf_python/sys/netlink/netlink_route.py @@ -0,0 +1,832 @@ +import socket +import struct +from ctypes import c_int +from ctypes import c_ubyte +from ctypes import c_uint +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from enum import auto +from enum import Enum + +from atf_python.sys.netlink.attrs import NlAttr +from atf_python.sys.netlink.attrs import NlAttrIp +from atf_python.sys.netlink.attrs import NlAttrNested +from atf_python.sys.netlink.attrs import NlAttrStr +from atf_python.sys.netlink.attrs import NlAttrU32 +from atf_python.sys.netlink.attrs import NlAttrU8 +from atf_python.sys.netlink.message import StdNetlinkMessage +from atf_python.sys.netlink.message import NlMsgProps +from atf_python.sys.netlink.message import NlMsgCategory +from atf_python.sys.netlink.utils import AttrDescr +from atf_python.sys.netlink.utils import get_bitmask_str +from atf_python.sys.netlink.utils import prepare_attrs_map + + +class RtattrType(Enum): + RTA_UNSPEC = 0 + RTA_DST = 1 + RTA_SRC = 2 + RTA_IIF = 3 + RTA_OIF = 4 + RTA_GATEWAY = 5 + RTA_PRIORITY = 6 + RTA_PREFSRC = 7 + RTA_METRICS = 8 + RTA_MULTIPATH = 9 + # RTA_PROTOINFO = 10 + RTA_KNH_ID = 10 + RTA_FLOW = 11 + RTA_CACHEINFO = 12 + RTA_SESSION = 13 + # RTA_MP_ALGO = 14 + RTA_RTFLAGS = 14 + RTA_TABLE = 15 + RTA_MARK = 16 + RTA_MFC_STATS = 17 + RTA_VIA = 18 + RTA_NEWDST = 19 + RTA_PREF = 20 + RTA_ENCAP_TYPE = 21 + RTA_ENCAP = 22 + RTA_EXPIRES = 23 + RTA_PAD = 24 + RTA_UID = 25 + RTA_TTL_PROPAGATE = 26 + RTA_IP_PROTO = 27 + RTA_SPORT = 28 + RTA_DPORT = 29 + RTA_NH_ID = 30 + + +class NlRtMsgType(Enum): + RTM_NEWLINK = 16 + RTM_DELLINK = 17 + RTM_GETLINK = 18 + RTM_SETLINK = 19 + RTM_NEWADDR = 20 + RTM_DELADDR = 21 + RTM_GETADDR = 22 + RTM_NEWROUTE = 24 + RTM_DELROUTE = 25 + RTM_GETROUTE = 26 + RTM_NEWNEIGH = 28 + RTM_DELNEIGH = 29 + RTM_GETNEIGH = 30 + RTM_NEWRULE = 32 + RTM_DELRULE = 33 + RTM_GETRULE = 34 + RTM_NEWQDISC = 36 + RTM_DELQDISC = 37 + RTM_GETQDISC = 38 + RTM_NEWTCLASS = 40 + RTM_DELTCLASS = 41 + RTM_GETTCLASS = 42 + RTM_NEWTFILTER = 44 + RTM_DELTFILTER = 45 + RTM_GETTFILTER = 46 + RTM_NEWACTION = 48 + RTM_DELACTION = 49 + RTM_GETACTION = 50 + RTM_NEWPREFIX = 52 + RTM_GETMULTICAST = 58 + RTM_GETANYCAST = 62 + RTM_NEWNEIGHTBL = 64 + RTM_GETNEIGHTBL = 66 + RTM_SETNEIGHTBL = 67 + RTM_NEWNDUSEROPT = 68 + RTM_NEWADDRLABEL = 72 + RTM_DELADDRLABEL = 73 + RTM_GETADDRLABEL = 74 + RTM_GETDCB = 78 + RTM_SETDCB = 79 + RTM_NEWNETCONF = 80 + RTM_GETNETCONF = 82 + RTM_NEWMDB = 84 + RTM_DELMDB = 85 + RTM_GETMDB = 86 + RTM_NEWNSID = 88 + RTM_DELNSID = 89 + RTM_GETNSID = 90 + RTM_NEWSTATS = 92 + RTM_GETSTATS = 94 + + +class RtAttr(Structure): + _fields_ = [ + ("rta_len", c_ushort), + ("rta_type", c_ushort), + ] + + +class RtMsgHdr(Structure): + _fields_ = [ + ("rtm_family", c_ubyte), + ("rtm_dst_len", c_ubyte), + ("rtm_src_len", c_ubyte), + ("rtm_tos", c_ubyte), + ("rtm_table", c_ubyte), + ("rtm_protocol", c_ubyte), + ("rtm_scope", c_ubyte), + ("rtm_type", c_ubyte), + ("rtm_flags", c_uint), + ] + + +class RtMsgFlags(Enum): + RTM_F_NOTIFY = 0x100 + RTM_F_CLONED = 0x200 + RTM_F_EQUALIZE = 0x400 + RTM_F_PREFIX = 0x800 + RTM_F_LOOKUP_TABLE = 0x1000 + RTM_F_FIB_MATCH = 0x2000 + RTM_F_OFFLOAD = 0x4000 + RTM_F_TRAP = 0x8000 + RTM_F_OFFLOAD_FAILED = 0x20000000 + + +class RtScope(Enum): + RT_SCOPE_UNIVERSE = 0 + RT_SCOPE_SITE = 200 + RT_SCOPE_LINK = 253 + RT_SCOPE_HOST = 254 + RT_SCOPE_NOWHERE = 255 + + +class RtType(Enum): + RTN_UNSPEC = 0 + RTN_UNICAST = auto() + RTN_LOCAL = auto() + RTN_BROADCAST = auto() + RTN_ANYCAST = auto() + RTN_MULTICAST = auto() + RTN_BLACKHOLE = auto() + RTN_UNREACHABLE = auto() + RTN_PROHIBIT = auto() + RTN_THROW = auto() + RTN_NAT = auto() + RTN_XRESOLVE = auto() + + +class RtProto(Enum): + RTPROT_UNSPEC = 0 + RTPROT_REDIRECT = 1 + RTPROT_KERNEL = 2 + RTPROT_BOOT = 3 + RTPROT_STATIC = 4 + RTPROT_GATED = 8 + RTPROT_RA = 9 + RTPROT_MRT = 10 + RTPROT_ZEBRA = 11 + RTPROT_BIRD = 12 + RTPROT_DNROUTED = 13 + RTPROT_XORP = 14 + RTPROT_NTK = 15 + RTPROT_DHCP = 16 + RTPROT_MROUTED = 17 + RTPROT_KEEPALIVED = 18 + RTPROT_BABEL = 42 + RTPROT_OPENR = 99 + RTPROT_BGP = 186 + RTPROT_ISIS = 187 + RTPROT_OSPF = 188 + RTPROT_RIP = 189 + RTPROT_EIGRP = 192 + + +class NlRtaxType(Enum): + RTAX_UNSPEC = 0 + RTAX_LOCK = auto() + RTAX_MTU = auto() + RTAX_WINDOW = auto() + RTAX_RTT = auto() + RTAX_RTTVAR = auto() + RTAX_SSTHRESH = auto() + RTAX_CWND = auto() + RTAX_ADVMSS = auto() + RTAX_REORDERING = auto() + RTAX_HOPLIMIT = auto() + RTAX_INITCWND = auto() + RTAX_FEATURES = auto() + RTAX_RTO_MIN = auto() + RTAX_INITRWND = auto() + RTAX_QUICKACK = auto() + RTAX_CC_ALGO = auto() + RTAX_FASTOPEN_NO_COOKIE = auto() + + +class RtFlagsBSD(Enum): + RTF_UP = 0x1 + RTF_GATEWAY = 0x2 + RTF_HOST = 0x4 + RTF_REJECT = 0x8 + RTF_DYNAMIC = 0x10 + RTF_MODIFIED = 0x20 + RTF_DONE = 0x40 + RTF_XRESOLVE = 0x200 + RTF_LLINFO = 0x400 + RTF_LLDATA = 0x400 + RTF_STATIC = 0x800 + RTF_BLACKHOLE = 0x1000 + RTF_PROTO2 = 0x4000 + RTF_PROTO1 = 0x8000 + RTF_PROTO3 = 0x40000 + RTF_FIXEDMTU = 0x80000 + RTF_PINNED = 0x100000 + RTF_LOCAL = 0x200000 + RTF_BROADCAST = 0x400000 + RTF_MULTICAST = 0x800000 + RTF_STICKY = 0x10000000 + RTF_RNH_LOCKED = 0x40000000 + RTF_GWFLAG_COMPAT = 0x80000000 + + +class NlRtGroup(Enum): + RTNLGRP_NONE = 0 + RTNLGRP_LINK = auto() + RTNLGRP_NOTIFY = auto() + RTNLGRP_NEIGH = auto() + RTNLGRP_TC = auto() + RTNLGRP_IPV4_IFADDR = auto() + RTNLGRP_IPV4_MROUTE = auto() + RTNLGRP_IPV4_ROUTE = auto() + RTNLGRP_IPV4_RULE = auto() + RTNLGRP_IPV6_IFADDR = auto() + RTNLGRP_IPV6_MROUTE = auto() + RTNLGRP_IPV6_ROUTE = auto() + RTNLGRP_IPV6_IFINFO = auto() + RTNLGRP_DECnet_IFADDR = auto() + RTNLGRP_NOP2 = auto() + RTNLGRP_DECnet_ROUTE = auto() + RTNLGRP_DECnet_RULE = auto() + RTNLGRP_NOP4 = auto() + RTNLGRP_IPV6_PREFIX = auto() + RTNLGRP_IPV6_RULE = auto() + RTNLGRP_ND_USEROPT = auto() + RTNLGRP_PHONET_IFADDR = auto() + RTNLGRP_PHONET_ROUTE = auto() + RTNLGRP_DCB = auto() + RTNLGRP_IPV4_NETCONF = auto() + RTNLGRP_IPV6_NETCONF = auto() + RTNLGRP_MDB = auto() + RTNLGRP_MPLS_ROUTE = auto() + RTNLGRP_NSID = auto() + RTNLGRP_MPLS_NETCONF = auto() + RTNLGRP_IPV4_MROUTE_R = auto() + RTNLGRP_IPV6_MROUTE_R = auto() + RTNLGRP_NEXTHOP = auto() + RTNLGRP_BRVLAN = auto() + + +class IfinfoMsg(Structure): + _fields_ = [ + ("ifi_family", c_ubyte), + ("__ifi_pad", c_ubyte), + ("ifi_type", c_ushort), + ("ifi_index", c_int), + ("ifi_flags", c_uint), + ("ifi_change", c_uint), + ] + + +class IflattrType(Enum): + IFLA_UNSPEC = 0 + IFLA_ADDRESS = 1 + IFLA_BROADCAST = 2 + IFLA_IFNAME = 3 + IFLA_MTU = 4 + IFLA_LINK = 5 + IFLA_QDISC = 6 + IFLA_STATS = 7 + IFLA_COST = 8 + IFLA_PRIORITY = 9 + IFLA_MASTER = 10 + IFLA_WIRELESS = 11 + IFLA_PROTINFO = 12 + IFLA_TXQLEN = 13 + IFLA_MAP = 14 + IFLA_WEIGHT = 15 + IFLA_OPERSTATE = 16 + IFLA_LINKMODE = 17 + IFLA_LINKINFO = 18 + IFLA_NET_NS_PID = 19 + IFLA_IFALIAS = 20 + IFLA_NUM_VF = 21 + IFLA_VFINFO_LIST = 22 + IFLA_STATS64 = 23 + IFLA_VF_PORTS = 24 + IFLA_PORT_SELF = 25 + IFLA_AF_SPEC = 26 + IFLA_GROUP = 27 + IFLA_NET_NS_FD = 28 + IFLA_EXT_MASK = 29 + IFLA_PROMISCUITY = 30 + IFLA_NUM_TX_QUEUES = 31 + IFLA_NUM_RX_QUEUES = 32 + IFLA_CARRIER = 33 + IFLA_PHYS_PORT_ID = 34 + IFLA_CARRIER_CHANGES = 35 + IFLA_PHYS_SWITCH_ID = 36 + IFLA_LINK_NETNSID = 37 + IFLA_PHYS_PORT_NAME = 38 + IFLA_PROTO_DOWN = 39 + IFLA_GSO_MAX_SEGS = 40 + IFLA_GSO_MAX_SIZE = 41 + IFLA_PAD = 42 + IFLA_XDP = 43 + IFLA_EVENT = 44 + IFLA_NEW_NETNSID = 45 + IFLA_IF_NETNSID = 46 + IFLA_CARRIER_UP_COUNT = 47 + IFLA_CARRIER_DOWN_COUNT = 48 + IFLA_NEW_IFINDEX = 49 + IFLA_MIN_MTU = 50 + IFLA_MAX_MTU = 51 + IFLA_PROP_LIST = 52 + IFLA_ALT_IFNAME = 53 + IFLA_PERM_ADDRESS = 54 + IFLA_PROTO_DOWN_REASON = 55 + IFLA_PARENT_DEV_NAME = 56 + IFLA_PARENT_DEV_BUS_NAME = 57 + IFLA_GRO_MAX_SIZE = 58 + IFLA_TSO_MAX_SEGS = 59 + IFLA_ALLMULTI = 60 + IFLA_DEVLINK_PORT = 61 + IFLA_GSO_IPV4_MAX_SIZE = 62 + IFLA_GRO_IPV4_MAX_SIZE = 63 + IFLA_FREEBSD = 64 + + +class IflafAttrType(Enum): + IFLAF_UNSPEC = 0 + IFLAF_ORIG_IFNAME = 1 + IFLAF_ORIG_HWADDR = 2 + + +class IflinkInfo(Enum): + IFLA_INFO_UNSPEC = 0 + IFLA_INFO_KIND = auto() + IFLA_INFO_DATA = auto() + IFLA_INFO_XSTATS = auto() + IFLA_INFO_SLAVE_KIND = auto() + IFLA_INFO_SLAVE_DATA = auto() + + +class IfLinkInfoDataVlan(Enum): + IFLA_VLAN_UNSPEC = 0 + IFLA_VLAN_ID = auto() + IFLA_VLAN_FLAGS = auto() + IFLA_VLAN_EGRESS_QOS = auto() + IFLA_VLAN_INGRESS_QOS = auto() + IFLA_VLAN_PROTOCOL = auto() + + +class IfaddrMsg(Structure): + _fields_ = [ + ("ifa_family", c_ubyte), + ("ifa_prefixlen", c_ubyte), + ("ifa_flags", c_ubyte), + ("ifa_scope", c_ubyte), + ("ifa_index", c_uint), + ] + + +class IfaAttrType(Enum): + IFA_UNSPEC = 0 + IFA_ADDRESS = 1 + IFA_LOCAL = 2 + IFA_LABEL = 3 + IFA_BROADCAST = 4 + IFA_ANYCAST = 5 + IFA_CACHEINFO = 6 + IFA_MULTICAST = 7 + IFA_FLAGS = 8 + IFA_RT_PRIORITY = 9 + IFA_TARGET_NETNSID = 10 + IFA_FREEBSD = 11 + + +class IfafAttrType(Enum): + IFAF_UNSPEC = 0 + IFAF_VHID = 1 + IFAF_FLAGS = 2 + + +class IfaCacheInfo(Structure): + _fields_ = [ + ("ifa_prefered", c_uint), # seconds till the end of the prefix considered preferred + ("ifa_valid", c_uint), # seconds till the end of the prefix considered valid + ("cstamp", c_uint), # creation time in 1ms intervals from the boot time + ("tstamp", c_uint), # update time in 1ms intervals from the boot time + ] + + +class IfaFlags(Enum): + IFA_F_TEMPORARY = 0x01 + IFA_F_NODAD = 0x02 + IFA_F_OPTIMISTIC = 0x04 + IFA_F_DADFAILED = 0x08 + IFA_F_HOMEADDRESS = 0x10 + IFA_F_DEPRECATED = 0x20 + IFA_F_TENTATIVE = 0x40 + IFA_F_PERMANENT = 0x80 + IFA_F_MANAGETEMPADDR = 0x100 + IFA_F_NOPREFIXROUTE = 0x200 + IFA_F_MCAUTOJOIN = 0x400 + IFA_F_STABLE_PRIVACY = 0x800 + + +class IfafFlags6(Enum): + IN6_IFF_ANYCAST = 0x01 + IN6_IFF_TENTATIVE = 0x02 + IN6_IFF_DUPLICATED = 0x04 + IN6_IFF_DETACHED = 0x08 + IN6_IFF_DEPRECATED = 0x10 + IN6_IFF_NODAD = 0x20 + IN6_IFF_AUTOCONF = 0x40 + IN6_IFF_TEMPORARY = 0x80 + IN6_IFF_PREFER_SOURCE = 0x100 + + +class NdMsg(Structure): + _fields_ = [ + ("ndm_family", c_ubyte), + ("ndm_pad1", c_ubyte), + ("ndm_pad2", c_ubyte), + ("ndm_ifindex", c_uint), + ("ndm_state", c_ushort), + ("ndm_flags", c_ubyte), + ("ndm_type", c_ubyte), + ] + + +class NdAttrType(Enum): + NDA_UNSPEC = 0 + NDA_DST = 1 + NDA_LLADDR = 2 + NDA_CACHEINFO = 3 + NDA_PROBES = 4 + NDA_VLAN = 5 + NDA_PORT = 6 + NDA_VNI = 7 + NDA_IFINDEX = 8 + NDA_MASTER = 9 + NDA_LINK_NETNSID = 10 + NDA_SRC_VNI = 11 + NDA_PROTOCOL = 12 + NDA_NH_ID = 13 + NDA_FDB_EXT_ATTRS = 14 + NDA_FLAGS_EXT = 15 + NDA_NDM_STATE_MASK = 16 + NDA_NDM_FLAGS_MASK = 17 + + +class NlAttrRtFlags(NlAttrU32): + def _print_attr_value(self): + s = get_bitmask_str(RtFlagsBSD, self.u32) + return " rtflags={}".format(s) + + +class NlAttrIfindex(NlAttrU32): + def _print_attr_value(self): + try: + ifname = socket.if_indextoname(self.u32) + return " iface={}(#{})".format(ifname, self.u32) + except OSError: + pass + return " iface=if#{}".format(self.u32) + + +class NlAttrTable(NlAttrU32): + def _print_attr_value(self): + return " rtable={}".format(self.u32) + + +class NlAttrNhId(NlAttrU32): + def _print_attr_value(self): + return " nh_id={}".format(self.u32) + + +class NlAttrKNhId(NlAttrU32): + def _print_attr_value(self): + return " knh_id={}".format(self.u32) + + +class NlAttrMac(NlAttr): + def _print_attr_value(self): + return ' mac="' + ":".join(["{:02X}".format(b) for b in self._data]) + '"' + + +class NlAttrIfStats(NlAttr): + def _print_attr_value(self): + return " stats={...}" + + +class NlAttrCacheInfo(NlAttr): + def __init__(self, nla_type, data): + super().__init__(nla_type, data) + self.ci = IfaCacheInfo.from_buffer_copy(data) + + @staticmethod + def _validate(data): + nla_len, nla_type = struct.unpack("@HH", data[:4]) + data_len = nla_len - 4 + if data_len != sizeof(IfaCacheInfo): + raise ValueError( + "Error validating attr {}: wrong size".format(nla_type) + ) # noqa: E501 + + def _print_attr_value(self): + return " ifa_prefered={} ifa_valid={} cstamp={} tstamp={}".format( + self.ci.ifa_prefered, self.ci.ifa_valid, self.ci.cstamp, self.ci.tstamp) + + +class NlAttrVia(NlAttr): + def __init__(self, nla_type, family, addr: str): + super().__init__(nla_type, b"") + self.addr = addr + self.family = family + + @staticmethod + def _validate(data): + nla_len, nla_type = struct.unpack("@HH", data[:4]) + data_len = nla_len - 4 + if data_len == 0: + raise ValueError( + "Error validating attr {}: empty data".format(nla_type) + ) # noqa: E501 + family = int(data_len[0]) + if family not in (socket.AF_INET, socket.AF_INET6): + raise ValueError( + "Error validating attr {}: unsupported AF {}".format( # noqa: E501 + nla_type, family + ) + ) + if family == socket.AF_INET: + expected_len = 1 + 4 + else: + expected_len = 1 + 16 + if data_len != expected_len: + raise ValueError( + "Error validating attr {}: expected len {} got {}".format( # noqa: E501 + nla_type, expected_len, data_len + ) + ) + + @property + def nla_len(self): + if self.family == socket.AF_INET6: + return 21 + else: + return 9 + + @classmethod + def _parse(cls, data): + nla_len, nla_type, family = struct.unpack("@HHB", data[:5]) + off = 5 + if family == socket.AF_INET: + addr = socket.inet_ntop(family, data[off:off + 4]) + else: + addr = socket.inet_ntop(family, data[off:off + 16]) + return cls(nla_type, family, addr) + + def __bytes__(self): + addr = socket.inet_pton(self.family, self.addr) + return self._to_bytes(struct.pack("@B", self.family) + addr) + + def _print_attr_value(self): + return " via={}".format(self.addr) + + +rtnl_route_attrs = prepare_attrs_map( + [ + AttrDescr(RtattrType.RTA_DST, NlAttrIp), + AttrDescr(RtattrType.RTA_SRC, NlAttrIp), + AttrDescr(RtattrType.RTA_IIF, NlAttrIfindex), + AttrDescr(RtattrType.RTA_OIF, NlAttrIfindex), + AttrDescr(RtattrType.RTA_GATEWAY, NlAttrIp), + AttrDescr(RtattrType.RTA_TABLE, NlAttrTable), + AttrDescr(RtattrType.RTA_PRIORITY, NlAttrU32), + AttrDescr(RtattrType.RTA_VIA, NlAttrVia), + AttrDescr(RtattrType.RTA_NH_ID, NlAttrNhId), + AttrDescr(RtattrType.RTA_KNH_ID, NlAttrKNhId), + AttrDescr(RtattrType.RTA_RTFLAGS, NlAttrRtFlags), + AttrDescr( + RtattrType.RTA_METRICS, + NlAttrNested, + [ + AttrDescr(NlRtaxType.RTAX_MTU, NlAttrU32), + ], + ), + ] +) + +rtnl_ifla_attrs = prepare_attrs_map( + [ + AttrDescr(IflattrType.IFLA_ADDRESS, NlAttrMac), + AttrDescr(IflattrType.IFLA_BROADCAST, NlAttrMac), + AttrDescr(IflattrType.IFLA_IFNAME, NlAttrStr), + AttrDescr(IflattrType.IFLA_MTU, NlAttrU32), + AttrDescr(IflattrType.IFLA_LINK, NlAttrU32), + AttrDescr(IflattrType.IFLA_PROMISCUITY, NlAttrU32), + AttrDescr(IflattrType.IFLA_OPERSTATE, NlAttrU8), + AttrDescr(IflattrType.IFLA_CARRIER, NlAttrU8), + AttrDescr(IflattrType.IFLA_IFALIAS, NlAttrStr), + AttrDescr(IflattrType.IFLA_STATS64, NlAttrIfStats), + AttrDescr(IflattrType.IFLA_NEW_IFINDEX, NlAttrU32), + AttrDescr( + IflattrType.IFLA_LINKINFO, + NlAttrNested, + [ + AttrDescr(IflinkInfo.IFLA_INFO_KIND, NlAttrStr), + AttrDescr(IflinkInfo.IFLA_INFO_DATA, NlAttr), + ], + ), + AttrDescr( + IflattrType.IFLA_FREEBSD, + NlAttrNested, + [ + AttrDescr(IflafAttrType.IFLAF_ORIG_HWADDR, NlAttrMac), + ], + ), + ] +) + +rtnl_ifa_attrs = prepare_attrs_map( + [ + AttrDescr(IfaAttrType.IFA_ADDRESS, NlAttrIp), + AttrDescr(IfaAttrType.IFA_LOCAL, NlAttrIp), + AttrDescr(IfaAttrType.IFA_LABEL, NlAttrStr), + AttrDescr(IfaAttrType.IFA_BROADCAST, NlAttrIp), + AttrDescr(IfaAttrType.IFA_ANYCAST, NlAttrIp), + AttrDescr(IfaAttrType.IFA_FLAGS, NlAttrU32), + AttrDescr(IfaAttrType.IFA_CACHEINFO, NlAttrCacheInfo), + AttrDescr( + IfaAttrType.IFA_FREEBSD, + NlAttrNested, + [ + AttrDescr(IfafAttrType.IFAF_VHID, NlAttrU32), + AttrDescr(IfafAttrType.IFAF_FLAGS, NlAttrU32), + ], + ), + ] +) + + +rtnl_nd_attrs = prepare_attrs_map( + [ + AttrDescr(NdAttrType.NDA_DST, NlAttrIp), + AttrDescr(NdAttrType.NDA_IFINDEX, NlAttrIfindex), + AttrDescr(NdAttrType.NDA_FLAGS_EXT, NlAttrU32), + AttrDescr(NdAttrType.NDA_LLADDR, NlAttrMac), + ] +) + + +class BaseNetlinkRtMessage(StdNetlinkMessage): + pass + + +class NetlinkRtMessage(BaseNetlinkRtMessage): + messages = [ + NlMsgProps(NlRtMsgType.RTM_NEWROUTE, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELROUTE, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETROUTE, NlMsgCategory.GET), + ] + nl_attrs_map = rtnl_route_attrs + + def __init__(self, helper, nlm_type): + super().__init__(helper, nlm_type) + self.base_hdr = RtMsgHdr() + + def parse_base_header(self, data): + if len(data) < sizeof(RtMsgHdr): + raise ValueError("length less than rtmsg header") + rtm_hdr = RtMsgHdr.from_buffer_copy(data) + return (rtm_hdr, sizeof(RtMsgHdr)) + + def print_base_header(self, hdr, prepend=""): + family = self.helper.get_af_name(hdr.rtm_family) + print( + "{}family={}, dst_len={}, src_len={}, tos={}, table={}, protocol={}({}), scope={}({}), type={}({}), flags={}({})".format( # noqa: E501 + prepend, + family, + hdr.rtm_dst_len, + hdr.rtm_src_len, + hdr.rtm_tos, + hdr.rtm_table, + self.helper.get_attr_byval(RtProto, hdr.rtm_protocol), + hdr.rtm_protocol, + self.helper.get_attr_byval(RtScope, hdr.rtm_scope), + hdr.rtm_scope, + self.helper.get_attr_byval(RtType, hdr.rtm_type), + hdr.rtm_type, + self.helper.get_bitmask_str(RtMsgFlags, hdr.rtm_flags), + hdr.rtm_flags, + ) + ) + + +class NetlinkIflaMessage(BaseNetlinkRtMessage): + messages = [ + NlMsgProps(NlRtMsgType.RTM_NEWLINK, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELLINK, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETLINK, NlMsgCategory.GET), + ] + nl_attrs_map = rtnl_ifla_attrs + + def __init__(self, helper, nlm_type): + super().__init__(helper, nlm_type) + self.base_hdr = IfinfoMsg() + + def parse_base_header(self, data): + if len(data) < sizeof(IfinfoMsg): + raise ValueError("length less than IfinfoMsg header") + rtm_hdr = IfinfoMsg.from_buffer_copy(data) + return (rtm_hdr, sizeof(IfinfoMsg)) + + def print_base_header(self, hdr, prepend=""): + family = self.helper.get_af_name(hdr.ifi_family) + print( + "{}family={}, ifi_type={}, ifi_index={}, ifi_flags={}, ifi_change={}".format( # noqa: E501 + prepend, + family, + hdr.ifi_type, + hdr.ifi_index, + hdr.ifi_flags, + hdr.ifi_change, + ) + ) + + +class NetlinkIfaMessage(BaseNetlinkRtMessage): + messages = [ + NlMsgProps(NlRtMsgType.RTM_NEWADDR, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELADDR, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETADDR, NlMsgCategory.GET), + ] + nl_attrs_map = rtnl_ifa_attrs + + def __init__(self, helper, nlm_type): + super().__init__(helper, nlm_type) + self.base_hdr = IfaddrMsg() + + def parse_base_header(self, data): + if len(data) < sizeof(IfaddrMsg): + raise ValueError("length less than IfaddrMsg header") + rtm_hdr = IfaddrMsg.from_buffer_copy(data) + return (rtm_hdr, sizeof(IfaddrMsg)) + + def print_base_header(self, hdr, prepend=""): + family = self.helper.get_af_name(hdr.ifa_family) + print( + "{}family={}, ifa_prefixlen={}, ifa_flags={}, ifa_scope={}, ifa_index={}".format( # noqa: E501 + prepend, + family, + hdr.ifa_prefixlen, + hdr.ifa_flags, + hdr.ifa_scope, + hdr.ifa_index, + ) + ) + + +class NetlinkNdMessage(BaseNetlinkRtMessage): + messages = [ + NlMsgProps(NlRtMsgType.RTM_NEWNEIGH, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELNEIGH, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETNEIGH, NlMsgCategory.GET), + ] + nl_attrs_map = rtnl_nd_attrs + + def __init__(self, helper, nlm_type): + super().__init__(helper, nlm_type) + self.base_hdr = NdMsg() + + def parse_base_header(self, data): + if len(data) < sizeof(NdMsg): + raise ValueError("length less than NdMsg header") + nd_hdr = NdMsg.from_buffer_copy(data) + return (nd_hdr, sizeof(NdMsg)) + + def print_base_header(self, hdr, prepend=""): + family = self.helper.get_af_name(hdr.ndm_family) + print( + "{}family={}, ndm_ifindex={}, ndm_state={}, ndm_flags={}".format( # noqa: E501 + prepend, + family, + hdr.ndm_ifindex, + hdr.ndm_state, + hdr.ndm_flags, + ) + ) + + +handler_classes = { + "netlink_route": [ + NetlinkRtMessage, + NetlinkIflaMessage, + NetlinkIfaMessage, + NetlinkNdMessage, + ], +} diff --git a/tests/atf_python/sys/netlink/utils.py b/tests/atf_python/sys/netlink/utils.py new file mode 100644 index 000000000000..f1d0ba3321ed --- /dev/null +++ b/tests/atf_python/sys/netlink/utils.py @@ -0,0 +1,80 @@ +#!/usr/local/bin/python3 +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple + + +class NlConst: + AF_NETLINK = 38 + NETLINK_ROUTE = 0 + NETLINK_GENERIC = 16 + GENL_ID_CTRL = 16 + + +def roundup2(val: int, num: int) -> int: + if val % num: + return (val | (num - 1)) + 1 + else: + return val + + +def align4(val: int) -> int: + return roundup2(val, 4) + + +def enum_or_int(val) -> int: + if isinstance(val, Enum): + return val.value + return val + + +class AttrDescr(NamedTuple): + val: Enum + cls: "NlAttr" + child_map: Any = None + is_array: bool = False + + +def prepare_attrs_map(attrs: List[AttrDescr]) -> Dict[str, Dict]: + ret = {} + for ad in attrs: + ret[ad.val.value] = {"ad": ad} + if ad.child_map: + ret[ad.val.value]["child"] = prepare_attrs_map(ad.child_map) + ret[ad.val.value]["is_array"] = ad.is_array + return ret + + +def build_propmap(cls): + ret = {} + for prop in dir(cls): + if not prop.startswith("_"): + ret[getattr(cls, prop).value] = prop + return ret + + +def get_bitmask_map(propmap, val): + v = 1 + ret = {} + while val: + if v & val: + if v in propmap: + ret[v] = propmap[v] + else: + ret[v] = hex(v) + val -= v + v *= 2 + return ret + + +def get_bitmask_str(cls, val): + if isinstance(cls, type): + pmap = build_propmap(cls) + else: + pmap = {} + for _cls in cls: + pmap.update(build_propmap(_cls)) + bmap = get_bitmask_map(pmap, val) + return ",".join([v for k, v in bmap.items()]) diff --git a/tests/atf_python/sys/netpfil/Makefile b/tests/atf_python/sys/netpfil/Makefile new file mode 100644 index 000000000000..47e7a0d4d4f1 --- /dev/null +++ b/tests/atf_python/sys/netpfil/Makefile @@ -0,0 +1,12 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +PACKAGE=tests +FILES= __init__.py +SUBDIR= ipfw + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python/sys/netpfil + +.include <bsd.prog.mk> diff --git a/tests/atf_python/sys/netpfil/__init__.py b/tests/atf_python/sys/netpfil/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 --- /dev/null +++ b/tests/atf_python/sys/netpfil/__init__.py diff --git a/tests/atf_python/sys/netpfil/ipfw/Makefile b/tests/atf_python/sys/netpfil/ipfw/Makefile new file mode 100644 index 000000000000..fde36de23c93 --- /dev/null +++ b/tests/atf_python/sys/netpfil/ipfw/Makefile @@ -0,0 +1,13 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +PACKAGE=tests +FILES= __init__.py insns.py insn_headers.py ioctl.py ioctl_headers.py \ + ipfw.py utils.py + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python/sys/netpfil/ipfw + +.include <bsd.prog.mk> + diff --git a/tests/atf_python/sys/netpfil/ipfw/__init__.py b/tests/atf_python/sys/netpfil/ipfw/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 --- /dev/null +++ b/tests/atf_python/sys/netpfil/ipfw/__init__.py diff --git a/tests/atf_python/sys/netpfil/ipfw/insn_headers.py b/tests/atf_python/sys/netpfil/ipfw/insn_headers.py new file mode 100644 index 000000000000..5c160d0758d6 --- /dev/null +++ b/tests/atf_python/sys/netpfil/ipfw/insn_headers.py @@ -0,0 +1,198 @@ +from enum import Enum + + +class IpFwOpcode(Enum): + O_NOP = 0 + O_IP_SRC = 1 + O_IP_SRC_MASK = 2 + O_IP_SRC_ME = 3 + O_IP_SRC_SET = 4 + O_IP_DST = 5 + O_IP_DST_MASK = 6 + O_IP_DST_ME = 7 + O_IP_DST_SET = 8 + O_IP_SRCPORT = 9 + O_IP_DSTPORT = 10 + O_PROTO = 11 + O_MACADDR2 = 12 + O_MAC_TYPE = 13 + O_LAYER2 = 14 + O_IN = 15 + O_FRAG = 16 + O_RECV = 17 + O_XMIT = 18 + O_VIA = 19 + O_IPOPT = 20 + O_IPLEN = 21 + O_IPID = 22 + O_IPTOS = 23 + O_IPPRECEDENCE = 24 + O_IPTTL = 25 + O_IPVER = 26 + O_UID = 27 + O_GID = 28 + O_ESTAB = 29 + O_TCPFLAGS = 30 + O_TCPWIN = 31 + O_TCPSEQ = 32 + O_TCPACK = 33 + O_ICMPTYPE = 34 + O_TCPOPTS = 35 + O_VERREVPATH = 36 + O_VERSRCREACH = 37 + O_PROBE_STATE = 38 + O_KEEP_STATE = 39 + O_LIMIT = 40 + O_LIMIT_PARENT = 41 + O_LOG = 42 + O_PROB = 43 + O_CHECK_STATE = 44 + O_ACCEPT = 45 + O_DENY = 46 + O_REJECT = 47 + O_COUNT = 48 + O_SKIPTO = 49 + O_PIPE = 50 + O_QUEUE = 51 + O_DIVERT = 52 + O_TEE = 53 + O_FORWARD_IP = 54 + O_FORWARD_MAC = 55 + O_NAT = 56 + O_REASS = 57 + O_IPSEC = 58 + O_IP_SRC_LOOKUP = 59 + O_IP_DST_LOOKUP = 60 + O_ANTISPOOF = 61 + O_JAIL = 62 + O_ALTQ = 63 + O_DIVERTED = 64 + O_TCPDATALEN = 65 + O_IP6_SRC = 66 + O_IP6_SRC_ME = 67 + O_IP6_SRC_MASK = 68 + O_IP6_DST = 69 + O_IP6_DST_ME = 70 + O_IP6_DST_MASK = 71 + O_FLOW6ID = 72 + O_ICMP6TYPE = 73 + O_EXT_HDR = 74 + O_IP6 = 75 + O_NETGRAPH = 76 + O_NGTEE = 77 + O_IP4 = 78 + O_UNREACH6 = 79 + O_TAG = 80 + O_TAGGED = 81 + O_SETFIB = 82 + O_FIB = 83 + O_SOCKARG = 84 + O_CALLRETURN = 85 + O_FORWARD_IP6 = 86 + O_DSCP = 87 + O_SETDSCP = 88 + O_IP_FLOW_LOOKUP = 89 + O_EXTERNAL_ACTION = 90 + O_EXTERNAL_INSTANCE = 91 + O_EXTERNAL_DATA = 92 + O_SKIP_ACTION = 93 + O_TCPMSS = 94 + O_MAC_SRC_LOOKUP = 95 + O_MAC_DST_LOOKUP = 96 + O_SETMARK = 97 + O_MARK = 98 + O_LAST_OPCODE = 99 + + +class Op3CmdType(Enum): + IP_FW_TABLE_XADD = 86 + IP_FW_TABLE_XDEL = 87 + IP_FW_TABLE_XGETSIZE = 88 + IP_FW_TABLE_XLIST = 89 + IP_FW_TABLE_XDESTROY = 90 + IP_FW_TABLES_XLIST = 92 + IP_FW_TABLE_XINFO = 93 + IP_FW_TABLE_XFLUSH = 94 + IP_FW_TABLE_XCREATE = 95 + IP_FW_TABLE_XMODIFY = 96 + IP_FW_XGET = 97 + IP_FW_XADD = 98 + IP_FW_XDEL = 99 + IP_FW_XMOVE = 100 + IP_FW_XZERO = 101 + IP_FW_XRESETLOG = 102 + IP_FW_SET_SWAP = 103 + IP_FW_SET_MOVE = 104 + IP_FW_SET_ENABLE = 105 + IP_FW_TABLE_XFIND = 106 + IP_FW_XIFLIST = 107 + IP_FW_TABLES_ALIST = 108 + IP_FW_TABLE_XSWAP = 109 + IP_FW_TABLE_VLIST = 110 + IP_FW_NAT44_XCONFIG = 111 + IP_FW_NAT44_DESTROY = 112 + IP_FW_NAT44_XGETCONFIG = 113 + IP_FW_NAT44_LIST_NAT = 114 + IP_FW_NAT44_XGETLOG = 115 + IP_FW_DUMP_SOPTCODES = 116 + IP_FW_DUMP_SRVOBJECTS = 117 + IP_FW_NAT64STL_CREATE = 130 + IP_FW_NAT64STL_DESTROY = 131 + IP_FW_NAT64STL_CONFIG = 132 + IP_FW_NAT64STL_LIST = 133 + IP_FW_NAT64STL_STATS = 134 + IP_FW_NAT64STL_RESET_STATS = 135 + IP_FW_NAT64LSN_CREATE = 140 + IP_FW_NAT64LSN_DESTROY = 141 + IP_FW_NAT64LSN_CONFIG = 142 + IP_FW_NAT64LSN_LIST = 143 + IP_FW_NAT64LSN_STATS = 144 + IP_FW_NAT64LSN_LIST_STATES = 145 + IP_FW_NAT64LSN_RESET_STATS = 146 + IP_FW_NPTV6_CREATE = 150 + IP_FW_NPTV6_DESTROY = 151 + IP_FW_NPTV6_CONFIG = 152 + IP_FW_NPTV6_LIST = 153 + IP_FW_NPTV6_STATS = 154 + IP_FW_NPTV6_RESET_STATS = 155 + IP_FW_NAT64CLAT_CREATE = 160 + IP_FW_NAT64CLAT_DESTROY = 161 + IP_FW_NAT64CLAT_CONFIG = 162 + IP_FW_NAT64CLAT_LIST = 163 + IP_FW_NAT64CLAT_STATS = 164 + IP_FW_NAT64CLAT_RESET_STATS = 165 + + +class IcmpRejectCode(Enum): + ICMP_UNREACH_NET = 0 + ICMP_UNREACH_HOST = 1 + ICMP_UNREACH_PROTOCOL = 2 + ICMP_UNREACH_PORT = 3 + ICMP_UNREACH_NEEDFRAG = 4 + ICMP_UNREACH_SRCFAIL = 5 + ICMP_UNREACH_NET_UNKNOWN = 6 + ICMP_UNREACH_HOST_UNKNOWN = 7 + ICMP_UNREACH_ISOLATED = 8 + ICMP_UNREACH_NET_PROHIB = 9 + ICMP_UNREACH_HOST_PROHIB = 10 + ICMP_UNREACH_TOSNET = 11 + ICMP_UNREACH_TOSHOST = 12 + ICMP_UNREACH_FILTER_PROHIB = 13 + ICMP_UNREACH_HOST_PRECEDENCE = 14 + ICMP_UNREACH_PRECEDENCE_CUTOFF = 15 + ICMP_REJECT_RST = 256 + ICMP_REJECT_ABORT = 257 + + +class Icmp6RejectCode(Enum): + ICMP6_DST_UNREACH_NOROUTE = 0 + ICMP6_DST_UNREACH_ADMIN = 1 + ICMP6_DST_UNREACH_BEYONDSCOPE = 2 + ICMP6_DST_UNREACH_NOTNEIGHBOR = 2 + ICMP6_DST_UNREACH_ADDR = 3 + ICMP6_DST_UNREACH_NOPORT = 4 + ICMP6_DST_UNREACH_POLICY = 5 + ICMP6_DST_UNREACH_REJECT = 6 + ICMP6_DST_UNREACH_SRCROUTE = 7 + ICMP6_UNREACH_RST = 256 + ICMP6_UNREACH_ABORT = 257 diff --git a/tests/atf_python/sys/netpfil/ipfw/insns.py b/tests/atf_python/sys/netpfil/ipfw/insns.py new file mode 100644 index 000000000000..f8a56de901ae --- /dev/null +++ b/tests/atf_python/sys/netpfil/ipfw/insns.py @@ -0,0 +1,558 @@ +#!/usr/bin/env python3 +import os +import socket +import struct +import subprocess +import sys +from ctypes import c_byte +from ctypes import c_char +from ctypes import c_int +from ctypes import c_long +from ctypes import c_uint32 +from ctypes import c_uint8 +from ctypes import c_ulong +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Union + +from atf_python.sys.netpfil.ipfw.insn_headers import IpFwOpcode +from atf_python.sys.netpfil.ipfw.insn_headers import IcmpRejectCode +from atf_python.sys.netpfil.ipfw.insn_headers import Icmp6RejectCode +from atf_python.sys.netpfil.ipfw.utils import AttrDescr +from atf_python.sys.netpfil.ipfw.utils import enum_or_int +from atf_python.sys.netpfil.ipfw.utils import enum_from_int +from atf_python.sys.netpfil.ipfw.utils import prepare_attrs_map + + +insn_actions = ( + IpFwOpcode.O_CHECK_STATE.value, + IpFwOpcode.O_REJECT.value, + IpFwOpcode.O_UNREACH6.value, + IpFwOpcode.O_ACCEPT.value, + IpFwOpcode.O_DENY.value, + IpFwOpcode.O_COUNT.value, + IpFwOpcode.O_NAT.value, + IpFwOpcode.O_QUEUE.value, + IpFwOpcode.O_PIPE.value, + IpFwOpcode.O_SKIPTO.value, + IpFwOpcode.O_NETGRAPH.value, + IpFwOpcode.O_NGTEE.value, + IpFwOpcode.O_DIVERT.value, + IpFwOpcode.O_TEE.value, + IpFwOpcode.O_CALLRETURN.value, + IpFwOpcode.O_FORWARD_IP.value, + IpFwOpcode.O_FORWARD_IP6.value, + IpFwOpcode.O_SETFIB.value, + IpFwOpcode.O_SETDSCP.value, + IpFwOpcode.O_REASS.value, + IpFwOpcode.O_SETMARK.value, + IpFwOpcode.O_EXTERNAL_ACTION.value, +) + + +class IpFwInsn(Structure): + _fields_ = [ + ("opcode", c_uint8), + ("length", c_uint8), + ("arg1", c_ushort), + ] + + +class BaseInsn(object): + obj_enum_class = IpFwOpcode + + def __init__(self, opcode, is_or, is_not, arg1): + if isinstance(opcode, Enum): + self.obj_type = opcode.value + self._enum = opcode + else: + self.obj_type = opcode + self._enum = enum_from_int(self.obj_enum_class, self.obj_type) + self.is_or = is_or + self.is_not = is_not + self.arg1 = arg1 + self.is_action = self.obj_type in insn_actions + self.ilen = 1 + self.obj_list = [] + + @property + def obj_name(self): + if self._enum is not None: + return self._enum.name + else: + return "opcode#{}".format(self.obj_type) + + @staticmethod + def get_insn_len(data: bytes) -> int: + (opcode_len,) = struct.unpack("@B", data[1:2]) + return opcode_len & 0x3F + + @classmethod + def _validate_len(cls, data, valid_options=None): + if len(data) < 4: + raise ValueError("opcode too short") + opcode_type, opcode_len = struct.unpack("@BB", data[:2]) + if len(data) != ((opcode_len & 0x3F) * 4): + raise ValueError("wrong length") + if valid_options and len(data) not in valid_options: + raise ValueError( + "len {} not in {} for {}".format( + len(data), valid_options, + enum_from_int(cls.obj_enum_class, data[0]) + ) + ) + + @classmethod + def _validate(cls, data): + cls._validate_len(data) + + @classmethod + def _parse(cls, data): + insn = IpFwInsn.from_buffer_copy(data[:4]) + is_or = (insn.length & 0x40) != 0 + is_not = (insn.length & 0x80) != 0 + return cls(opcode=insn.opcode, is_or=is_or, is_not=is_not, arg1=insn.arg1) + + @classmethod + def from_bytes(cls, data, attr_type_enum): + cls._validate(data) + opcode = cls._parse(data) + opcode._enum = attr_type_enum + return opcode + + def __bytes__(self): + raise NotImplementedError() + + def print_obj(self, prepend=""): + is_or = "" + if self.is_or: + is_or = " [OR]\\" + is_not = "" + if self.is_not: + is_not = "[!] " + print( + "{}{}len={} type={}({}){}{}".format( + prepend, + is_not, + len(bytes(self)), + self.obj_name, + self.obj_type, + self._print_obj_value(), + is_or, + ) + ) + + def _print_obj_value(self): + raise NotImplementedError() + + def print_obj_hex(self, prepend=""): + print(prepend) + print() + print(" ".join(["x{:02X}".format(b) for b in bytes(self)])) + + @staticmethod + def parse_insns(data, attr_map): + ret = [] + off = 0 + while off + sizeof(IpFwInsn) <= len(data): + hdr = IpFwInsn.from_buffer_copy(data[off : off + sizeof(IpFwInsn)]) + insn_len = (hdr.length & 0x3F) * 4 + if off + insn_len > len(data): + raise ValueError("wrng length") + # print("GET insn type {} len {}".format(hdr.opcode, insn_len)) + attr = attr_map.get(hdr.opcode, None) + if attr is None: + cls = InsnUnknown + type_enum = enum_from_int(BaseInsn.obj_enum_class, hdr.opcode) + else: + cls = attr["ad"].cls + type_enum = attr["ad"].val + insn = cls.from_bytes(data[off : off + insn_len], type_enum) + ret.append(insn) + off += insn_len + + if off != len(data): + raise ValueError("empty space") + return ret + + +class Insn(BaseInsn): + def __init__(self, opcode, is_or=False, is_not=False, arg1=0): + super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1) + + @classmethod + def _validate(cls, data): + cls._validate_len(data, [4]) + + def __bytes__(self): + length = self.ilen + if self.is_or: + length |= 0x40 + if self.is_not: + length | 0x80 + insn = IpFwInsn(opcode=self.obj_type, length=length, arg1=enum_or_int(self.arg1)) + return bytes(insn) + + def _print_obj_value(self): + return " arg1={}".format(self.arg1) + + +class InsnUnknown(Insn): + @classmethod + def _validate(cls, data): + cls._validate_len(data) + + @classmethod + def _parse(cls, data): + self = super()._parse(data) + self._data = data + return self + + def __bytes__(self): + return self._data + + def _print_obj_value(self): + return " " + " ".join(["x{:02X}".format(b) for b in self._data]) + + +class InsnEmpty(Insn): + @classmethod + def _validate(cls, data): + cls._validate_len(data, [4]) + insn = IpFwInsn.from_buffer_copy(data[:4]) + if insn.arg1 != 0: + raise ValueError("arg1 should be empty") + + def _print_obj_value(self): + return "" + + +class InsnComment(Insn): + def __init__(self, opcode=IpFwOpcode.O_NOP, is_or=False, is_not=False, arg1=0, comment=""): + super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1) + if comment: + self.comment = comment + else: + self.comment = "" + + @classmethod + def _validate(cls, data): + cls._validate_len(data) + if len(data) > 88: + raise ValueError("comment too long") + + @classmethod + def _parse(cls, data): + self = super()._parse(data) + # Comment encoding can be anything, + # use utf-8 to ease debugging + max_len = 0 + for b in range(4, len(data)): + if data[b] == b"\0": + break + max_len += 1 + self.comment = data[4:max_len].decode("utf-8") + return self + + def __bytes__(self): + ret = super().__bytes__() + comment_bytes = self.comment.encode("utf-8") + b"\0" + if len(comment_bytes) % 4 > 0: + comment_bytes += b"\0" * (4 - (len(comment_bytes) % 4)) + ret += comment_bytes + return ret + + def _print_obj_value(self): + return " comment='{}'".format(self.comment) + + +class InsnProto(Insn): + def __init__(self, opcode=IpFwOpcode.O_PROTO, is_or=False, is_not=False, arg1=0): + super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1) + + def _print_obj_value(self): + known_map = {6: "TCP", 17: "UDP", 41: "IPV6"} + proto = self.arg1 + if proto in known_map: + return " proto={}".format(known_map[proto]) + else: + return " proto=#{}".format(proto) + + +class InsnU32(Insn): + def __init__(self, opcode, is_or=False, is_not=False, arg1=0, u32=0): + super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1) + self.u32 = u32 + self.ilen = 2 + + @classmethod + def _validate(cls, data): + cls._validate_len(data, [8]) + + @classmethod + def _parse(cls, data): + self = super()._parse(data[:4]) + self.u32 = struct.unpack("@I", data[4:8])[0] + return self + + def __bytes__(self): + return super().__bytes__() + struct.pack("@I", self.u32) + + def _print_obj_value(self): + return " arg1={} u32={}".format(self.arg1, self.u32) + + +class InsnProb(InsnU32): + def __init__( + self, + opcode=IpFwOpcode.O_PROB, + is_or=False, + is_not=False, + arg1=0, + u32=0, + prob=0.0, + ): + super().__init__(opcode, is_or=is_or, is_not=is_not) + self.prob = prob + + @property + def prob(self): + return 1.0 * self.u32 / 0x7FFFFFFF + + @prob.setter + def prob(self, prob: float): + self.u32 = int(prob * 0x7FFFFFFF) + + def _print_obj_value(self): + return " prob={}".format(round(self.prob, 5)) + + +class InsnIp(InsnU32): + def __init__(self, opcode, is_or=False, is_not=False, arg1=0, u32=0, ip=None): + super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1, u32=u32) + if ip: + self.ip = ip + + @property + def ip(self): + return socket.inet_ntop(socket.AF_INET, struct.pack("@I", self.u32)) + + @ip.setter + def ip(self, ip: str): + ip_bin = socket.inet_pton(socket.AF_INET, ip) + self.u32 = struct.unpack("@I", ip_bin)[0] + + def _print_opcode_value(self): + return " ip={}".format(self.ip) + + +class InsnTable(Insn): + @classmethod + def _validate(cls, data): + cls._validate_len(data, [4, 8]) + + @classmethod + def _parse(cls, data): + self = super()._parse(data) + + if len(data) == 8: + (self.val,) = struct.unpack("@I", data[4:8]) + self.ilen = 2 + else: + self.val = None + return self + + def __bytes__(self): + ret = super().__bytes__() + if getattr(self, "val", None) is not None: + ret += struct.pack("@I", self.val) + return ret + + def _print_obj_value(self): + if getattr(self, "val", None) is not None: + return " table={} value={}".format(self.arg1, self.val) + else: + return " table={}".format(self.arg1) + + +class InsnReject(Insn): + def __init__(self, opcode, is_or=False, is_not=False, arg1=0, mtu=None): + super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1) + self.mtu = mtu + if self.mtu is not None: + self.ilen = 2 + + @classmethod + def _validate(cls, data): + cls._validate_len(data, [4, 8]) + + @classmethod + def _parse(cls, data): + self = super()._parse(data) + + if len(data) == 8: + (self.mtu,) = struct.unpack("@I", data[4:8]) + self.ilen = 2 + else: + self.mtu = None + return self + + def __bytes__(self): + ret = super().__bytes__() + if getattr(self, "mtu", None) is not None: + ret += struct.pack("@I", self.mtu) + return ret + + def _print_obj_value(self): + code = enum_from_int(IcmpRejectCode, self.arg1) + if getattr(self, "mtu", None) is not None: + return " code={} mtu={}".format(code, self.mtu) + else: + return " code={}".format(code) + + +class InsnPorts(Insn): + def __init__(self, opcode, is_or=False, is_not=False, arg1=0, port_pairs=[]): + super().__init__(opcode, is_or=is_or, is_not=is_not) + self.port_pairs = [] + if port_pairs: + self.port_pairs = port_pairs + + @classmethod + def _validate(cls, data): + if len(data) < 8: + raise ValueError("no ports specified") + cls._validate_len(data) + + @classmethod + def _parse(cls, data): + self = super()._parse(data) + + off = 4 + port_pairs = [] + while off + 4 <= len(data): + low, high = struct.unpack("@HH", data[off : off + 4]) + port_pairs.append((low, high)) + off += 4 + self.port_pairs = port_pairs + return self + + def __bytes__(self): + ret = super().__bytes__() + if getattr(self, "val", None) is not None: + ret += struct.pack("@I", self.val) + return ret + + def _print_obj_value(self): + ret = [] + for p in self.port_pairs: + if p[0] == p[1]: + ret.append(str(p[0])) + else: + ret.append("{}-{}".format(p[0], p[1])) + return " ports={}".format(",".join(ret)) + + +class IpFwInsnIp6(Structure): + _fields_ = [ + ("o", IpFwInsn), + ("addr6", c_byte * 16), + ("mask6", c_byte * 16), + ] + + +class InsnIp6(Insn): + def __init__(self, opcode, is_or=False, is_not=False, arg1=0, ip6=None, mask6=None): + super().__init__(opcode, is_or=is_or, is_not=is_not, arg1=arg1) + self.ip6 = ip6 + self.mask6 = mask6 + if mask6 is not None: + self.ilen = 9 + else: + self.ilen = 5 + + @classmethod + def _validate(cls, data): + cls._validate_len(data, [4 + 16, 4 + 16 * 2]) + + @classmethod + def _parse(cls, data): + self = super()._parse(data) + self.ip6 = socket.inet_ntop(socket.AF_INET6, data[4:20]) + + if len(data) == 4 + 16 * 2: + self.mask6 = socket.inet_ntop(socket.AF_INET6, data[20:36]) + self.ilen = 9 + else: + self.mask6 = None + self.ilen = 5 + return self + + def __bytes__(self): + ret = super().__bytes__() + socket.inet_pton(socket.AF_INET6, self.ip6) + if self.mask6 is not None: + ret += socket.inet_pton(socket.AF_INET6, self.mask6) + return ret + + def _print_obj_value(self): + if self.mask6: + return " ip6={}/{}".format(self.ip6, self.mask6) + else: + return " ip6={}".format(self.ip6) + + +insn_attrs = prepare_attrs_map( + [ + AttrDescr(IpFwOpcode.O_CHECK_STATE, InsnU32), + AttrDescr(IpFwOpcode.O_ACCEPT, InsnEmpty), + AttrDescr(IpFwOpcode.O_COUNT, InsnEmpty), + + AttrDescr(IpFwOpcode.O_REJECT, InsnReject), + AttrDescr(IpFwOpcode.O_UNREACH6, Insn), + AttrDescr(IpFwOpcode.O_DENY, InsnEmpty), + AttrDescr(IpFwOpcode.O_DIVERT, Insn), + AttrDescr(IpFwOpcode.O_COUNT, InsnEmpty), + AttrDescr(IpFwOpcode.O_QUEUE, Insn), + AttrDescr(IpFwOpcode.O_PIPE, Insn), + AttrDescr(IpFwOpcode.O_SKIPTO, InsnU32), + AttrDescr(IpFwOpcode.O_NETGRAPH, Insn), + AttrDescr(IpFwOpcode.O_NGTEE, Insn), + AttrDescr(IpFwOpcode.O_DIVERT, Insn), + AttrDescr(IpFwOpcode.O_TEE, Insn), + AttrDescr(IpFwOpcode.O_CALLRETURN, InsnU32), + AttrDescr(IpFwOpcode.O_SETFIB, Insn), + AttrDescr(IpFwOpcode.O_SETDSCP, Insn), + AttrDescr(IpFwOpcode.O_REASS, InsnEmpty), + AttrDescr(IpFwOpcode.O_SETMARK, InsnU32), + + AttrDescr(IpFwOpcode.O_EXTERNAL_ACTION, InsnU32), + AttrDescr(IpFwOpcode.O_EXTERNAL_INSTANCE, InsnU32), + + + + AttrDescr(IpFwOpcode.O_NOP, InsnComment), + AttrDescr(IpFwOpcode.O_PROTO, InsnProto), + AttrDescr(IpFwOpcode.O_PROB, InsnProb), + AttrDescr(IpFwOpcode.O_IP_DST_ME, InsnEmpty), + AttrDescr(IpFwOpcode.O_IP_SRC_ME, InsnEmpty), + AttrDescr(IpFwOpcode.O_IP6_DST_ME, InsnEmpty), + AttrDescr(IpFwOpcode.O_IP6_SRC_ME, InsnEmpty), + AttrDescr(IpFwOpcode.O_IP_SRC, InsnIp), + AttrDescr(IpFwOpcode.O_IP_DST, InsnIp), + AttrDescr(IpFwOpcode.O_IP6_DST, InsnIp6), + AttrDescr(IpFwOpcode.O_IP6_SRC, InsnIp6), + AttrDescr(IpFwOpcode.O_IP_SRC_LOOKUP, InsnU32), + AttrDescr(IpFwOpcode.O_IP_DST_LOOKUP, InsnU32), + AttrDescr(IpFwOpcode.O_IP_SRCPORT, InsnPorts), + AttrDescr(IpFwOpcode.O_IP_DSTPORT, InsnPorts), + AttrDescr(IpFwOpcode.O_PROBE_STATE, InsnU32), + AttrDescr(IpFwOpcode.O_KEEP_STATE, InsnU32), + ] +) diff --git a/tests/atf_python/sys/netpfil/ipfw/ioctl.py b/tests/atf_python/sys/netpfil/ipfw/ioctl.py new file mode 100644 index 000000000000..4c6d3f234c6c --- /dev/null +++ b/tests/atf_python/sys/netpfil/ipfw/ioctl.py @@ -0,0 +1,511 @@ +#!/usr/bin/env python3 +import os +import socket +import struct +import subprocess +import sys +from ctypes import c_byte +from ctypes import c_char +from ctypes import c_int +from ctypes import c_long +from ctypes import c_uint32 +from ctypes import c_uint8 +from ctypes import c_ulong +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Union + +import pytest +from atf_python.sys.netpfil.ipfw.insns import BaseInsn +from atf_python.sys.netpfil.ipfw.insns import insn_attrs +from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTableLookupType +from atf_python.sys.netpfil.ipfw.ioctl_headers import IpFwTlvType +from atf_python.sys.netpfil.ipfw.ioctl_headers import Op3CmdType +from atf_python.sys.netpfil.ipfw.utils import align8 +from atf_python.sys.netpfil.ipfw.utils import AttrDescr +from atf_python.sys.netpfil.ipfw.utils import enum_from_int +from atf_python.sys.netpfil.ipfw.utils import prepare_attrs_map + + +class IpFw3OpHeader(Structure): + _fields_ = [ + ("opcode", c_ushort), + ("version", c_ushort), + ("reserved1", c_ushort), + ("reserved2", c_ushort), + ] + + +class IpFwObjTlv(Structure): + _fields_ = [ + ("n_type", c_ushort), + ("flags", c_ushort), + ("length", c_uint32), + ] + + +class BaseTlv(object): + obj_enum_class = IpFwTlvType + + def __init__(self, obj_type): + if isinstance(obj_type, Enum): + self.obj_type = obj_type.value + self._enum = obj_type + else: + self.obj_type = obj_type + self._enum = enum_from_int(self.obj_enum_class, obj_type) + self.obj_list = [] + + def add_obj(self, obj): + self.obj_list.append(obj) + + @property + def len(self): + return len(bytes(self)) + + @property + def obj_name(self): + if self._enum is not None: + return self._enum.name + else: + return "tlv#{}".format(self.obj_type) + + def print_hdr(self, prepend=""): + print( + "{}len={} type={}({}){}".format( + prepend, self.len, self.obj_name, self.obj_type, self._print_obj_value() + ) + ) + + def print_obj(self, prepend=""): + self.print_hdr(prepend) + prepend = " " + prepend + for obj in self.obj_list: + obj.print_obj(prepend) + + def print_obj_hex(self, prepend=""): + print(prepend) + print() + print(" ".join(["x{:02X}".format(b) for b in bytes(self)])) + + @classmethod + def _validate(cls, data): + if len(data) < sizeof(IpFwObjTlv): + raise ValueError("TLV too short") + hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)]) + if len(data) != hdr.length: + raise ValueError("wrong TLV size") + + @classmethod + def _parse(cls, data, attr_map): + hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)]) + return cls(hdr.n_type) + + @classmethod + def from_bytes(cls, data, attr_map=None): + cls._validate(data) + obj = cls._parse(data, attr_map) + return obj + + def __bytes__(self): + raise NotImplementedError() + + def _print_obj_value(self): + return " " + " ".join( + ["x{:02X}".format(b) for b in self._data[sizeof(IpFwObjTlv) :]] + ) + + def as_hexdump(self): + return " ".join(["x{:02X}".format(b) for b in bytes(self)]) + + +class UnknownTlv(BaseTlv): + def __init__(self, obj_type, data): + super().__init__(obj_type) + self._data = data + + @classmethod + def _validate(cls, data): + if len(data) < sizeof(IpFwObjNTlv): + raise ValueError("TLV size is too short") + hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)]) + if len(data) != hdr.length: + raise ValueError("wrong TLV size") + + @classmethod + def _parse(cls, data, attr_map): + hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)]) + self = cls(hdr.n_type, data) + return self + + def __bytes__(self): + return self._data + + +class Tlv(BaseTlv): + @staticmethod + def parse_tlvs(data, attr_map): + # print("PARSING " + " ".join(["x{:02X}".format(b) for b in data])) + off = 0 + ret = [] + while off + sizeof(IpFwObjTlv) <= len(data): + hdr = IpFwObjTlv.from_buffer_copy(data[off : off + sizeof(IpFwObjTlv)]) + if off + hdr.length > len(data): + raise ValueError("TLV size do not match") + obj_data = data[off : off + hdr.length] + obj_descr = attr_map.get(hdr.n_type, None) + if obj_descr is None: + # raise ValueError("unknown child TLV {}".format(hdr.n_type)) + cls = UnknownTlv + child_map = {} + else: + cls = obj_descr["ad"].cls + child_map = obj_descr.get("child", {}) + # print("FOUND OBJECT type {}".format(cls)) + # print() + obj = cls.from_bytes(obj_data, child_map) + ret.append(obj) + off += hdr.length + return ret + + +class IpFwObjNTlv(Structure): + _fields_ = [ + ("head", IpFwObjTlv), + ("idx", c_ushort), + ("n_set", c_uint8), + ("n_type", c_uint8), + ("spare", c_uint32), + ("name", c_char * 64), + ] + + +class NTlv(Tlv): + def __init__(self, obj_type, idx=0, n_set=0, n_type=0, name=None): + super().__init__(obj_type) + self.n_idx = idx + self.n_set = n_set + self.n_type = n_type + self.n_name = name + + @classmethod + def _validate(cls, data): + if len(data) != sizeof(IpFwObjNTlv): + raise ValueError("TLV size is not correct") + hdr = IpFwObjTlv.from_buffer_copy(data[: sizeof(IpFwObjTlv)]) + if len(data) != hdr.length: + raise ValueError("wrong TLV size") + + @classmethod + def _parse(cls, data, attr_map): + hdr = IpFwObjNTlv.from_buffer_copy(data[: sizeof(IpFwObjNTlv)]) + name = hdr.name.decode("utf-8") + self = cls(hdr.head.n_type, hdr.idx, hdr.n_set, hdr.n_type, name) + return self + + def __bytes__(self): + name_bytes = self.n_name.encode("utf-8") + if len(name_bytes) < 64: + name_bytes += b"\0" * (64 - len(name_bytes)) + hdr = IpFwObjNTlv( + head=IpFwObjTlv(n_type=self.obj_type, length=sizeof(IpFwObjNTlv)), + idx=self.n_idx, + n_set=self.n_set, + n_type=self.n_type, + name=name_bytes[:64], + ) + return bytes(hdr) + + def _print_obj_value(self): + return " " + "type={} set={} idx={} name={}".format( + self.n_type, self.n_set, self.n_idx, self.n_name + ) + + +class IpFwObjCTlv(Structure): + _fields_ = [ + ("head", IpFwObjTlv), + ("count", c_uint32), + ("objsize", c_ushort), + ("version", c_uint8), + ("flags", c_uint8), + ] + + +class CTlv(Tlv): + def __init__(self, obj_type, obj_list=[]): + super().__init__(obj_type) + if obj_list: + self.obj_list.extend(obj_list) + + @classmethod + def _validate(cls, data): + if len(data) < sizeof(IpFwObjCTlv): + raise ValueError("TLV too short") + hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)]) + if len(data) != hdr.head.length: + raise ValueError("wrong TLV size") + + @classmethod + def _parse(cls, data, attr_map): + hdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)]) + tlv_list = cls.parse_tlvs(data[sizeof(IpFwObjCTlv) :], attr_map) + if len(tlv_list) != hdr.count: + raise ValueError("wrong number of objects") + self = cls(hdr.head.n_type, obj_list=tlv_list) + return self + + def __bytes__(self): + ret = b"" + for obj in self.obj_list: + ret += bytes(obj) + length = len(ret) + sizeof(IpFwObjCTlv) + if self.obj_list: + objsize = len(bytes(self.obj_list[0])) + else: + objsize = 0 + hdr = IpFwObjCTlv( + head=IpFwObjTlv(n_type=self.obj_type, length=sizeof(IpFwObjNTlv)), + count=len(self.obj_list), + objsize=objsize, + ) + return bytes(hdr) + ret + + def _print_obj_value(self): + return "" + + +class IpFwRule(Structure): + _fields_ = [ + ("act_ofs", c_ushort), + ("cmd_len", c_ushort), + ("spare", c_ushort), + ("n_set", c_uint8), + ("flags", c_uint8), + ("rulenum", c_uint32), + ("n_id", c_uint32), + ] + + +class RawRule(Tlv): + def __init__(self, obj_type=0, n_set=0, rulenum=0, obj_list=[]): + super().__init__(obj_type) + self.n_set = n_set + self.rulenum = rulenum + if obj_list: + self.obj_list.extend(obj_list) + + @classmethod + def _validate(cls, data): + min_size = sizeof(IpFwRule) + if len(data) < min_size: + raise ValueError("rule TLV too short") + rule = IpFwRule.from_buffer_copy(data[:min_size]) + if len(data) != min_size + rule.cmd_len * 4: + raise ValueError("rule TLV cmd_len incorrect") + + @classmethod + def _parse(cls, data, attr_map): + hdr = IpFwRule.from_buffer_copy(data[: sizeof(IpFwRule)]) + self = cls( + n_set=hdr.n_set, + rulenum=hdr.rulenum, + obj_list=BaseInsn.parse_insns(data[sizeof(IpFwRule) :], insn_attrs), + ) + return self + + def __bytes__(self): + act_ofs = 0 + cmd_len = 0 + ret = b"" + for obj in self.obj_list: + if obj.is_action and act_ofs == 0: + act_ofs = cmd_len + obj_bytes = bytes(obj) + cmd_len += len(obj_bytes) // 4 + ret += obj_bytes + + hdr = IpFwRule( + act_ofs=act_ofs, + cmd_len=cmd_len, + n_set=self.n_set, + rulenum=self.rulenum, + ) + return bytes(hdr) + ret + + @property + def obj_name(self): + return "rule#{}".format(self.rulenum) + + def _print_obj_value(self): + cmd_len = sum([len(bytes(obj)) for obj in self.obj_list]) // 4 + return " set={} cmd_len={}".format(self.n_set, cmd_len) + + +class CTlvRule(CTlv): + def __init__(self, obj_type=IpFwTlvType.IPFW_TLV_RULE_LIST, obj_list=[]): + super().__init__(obj_type, obj_list) + + @classmethod + def _parse(cls, data, attr_map): + chdr = IpFwObjCTlv.from_buffer_copy(data[: sizeof(IpFwObjCTlv)]) + rule_list = [] + off = sizeof(IpFwObjCTlv) + while off + sizeof(IpFwRule) <= len(data): + hdr = IpFwRule.from_buffer_copy(data[off : off + sizeof(IpFwRule)]) + rule_len = sizeof(IpFwRule) + hdr.cmd_len * 4 + # print("FOUND RULE len={} cmd_len={}".format(rule_len, hdr.cmd_len)) + if off + rule_len > len(data): + raise ValueError("wrong rule size") + rule = RawRule.from_bytes(data[off : off + rule_len]) + rule_list.append(rule) + off += align8(rule_len) + if off != len(data): + raise ValueError("rule bytes left: off={} len={}".format(off, len(data))) + return cls(chdr.head.n_type, obj_list=rule_list) + + # XXX: _validate + + def __bytes__(self): + ret = b"" + for rule in self.obj_list: + rule_bytes = bytes(rule) + remainder = len(rule_bytes) % 8 + if remainder > 0: + rule_bytes += b"\0" * (8 - remainder) + ret += rule_bytes + hdr = IpFwObjCTlv( + head=IpFwObjTlv( + n_type=self.obj_type, length=len(ret) + sizeof(IpFwObjCTlv) + ), + count=len(self.obj_list), + ) + return bytes(hdr) + ret + + +class BaseIpFwMessage(object): + messages = [] + + def __init__(self, msg_type, obj_list=[]): + if isinstance(msg_type, Enum): + self.obj_type = msg_type.value + self._enum = msg_type + else: + self.obj_type = msg_type + self._enum = enum_from_int(self.messages, self.obj_type) + self.obj_list = [] + if obj_list: + self.obj_list.extend(obj_list) + + def add_obj(self, obj): + self.obj_list.append(obj) + + def get_obj(self, obj_type): + obj_type_raw = enum_or_int(obj_type) + for obj in self.obj_list: + if obj.obj_type == obj_type_raw: + return obj + return None + + @staticmethod + def parse_header(data: bytes): + if len(data) < sizeof(IpFw3OpHeader): + raise ValueError("length less than op3 message header") + return IpFw3OpHeader.from_buffer_copy(data), sizeof(IpFw3OpHeader) + + def parse_obj_list(self, data: bytes): + off = 0 + while off < len(data): + # print("PARSE off={} rem={}".format(off, len(data) - off)) + hdr = IpFwObjTlv.from_buffer_copy(data[off : off + sizeof(IpFwObjTlv)]) + # print(" tlv len {}".format(hdr.length)) + if hdr.length + off > len(data): + raise ValueError("TLV too big") + tlv = Tlv(hdr.n_type, data[off : off + hdr.length]) + self.add_obj(tlv) + off += hdr.length + + def is_type(self, msg_type): + return enum_or_int(msg_type) == self.msg_type + + @property + def obj_name(self): + if self._enum is not None: + return self._enum.name + else: + return "msg#{}".format(self.obj_type) + + def print_hdr(self, prepend=""): + print("{}len={}, type={}".format(prepend, len(bytes(self)), self.obj_name)) + + @classmethod + def from_bytes(cls, data): + try: + hdr, hdrlen = cls.parse_header(data) + self = cls(hdr.opcode) + self._orig_data = data + except ValueError as e: + print("Failed to parse op3 header: {}".format(e)) + cls.print_as_bytes(data) + raise + tlv_list = Tlv.parse_tlvs(data[hdrlen:], self.attr_map) + self.obj_list.extend(tlv_list) + return self + + def __bytes__(self): + ret = bytes(IpFw3OpHeader(opcode=self.obj_type)) + for obj in self.obj_list: + ret += bytes(obj) + return ret + + def print_obj(self): + self.print_hdr() + for obj in self.obj_list: + obj.print_obj(" ") + + @staticmethod + def print_as_bytes(data: bytes, descr: str): + print("===vv {} (len:{:3d}) vv===".format(descr, len(data))) + off = 0 + step = 16 + while off < len(data): + for i in range(step): + if off + i < len(data): + print(" {:02X}".format(data[off + i]), end="") + print("") + off += step + print("--------------------") + + +rule_attrs = prepare_attrs_map( + [ + AttrDescr( + IpFwTlvType.IPFW_TLV_TBLNAME_LIST, + CTlv, + [ + AttrDescr(IpFwTlvType.IPFW_TLV_TBL_NAME, NTlv), + AttrDescr(IpFwTlvType.IPFW_TLV_STATE_NAME, NTlv), + AttrDescr(IpFwTlvType.IPFW_TLV_EACTION, NTlv), + ], + True, + ), + AttrDescr(IpFwTlvType.IPFW_TLV_RULE_LIST, CTlvRule), + ] +) + + +class IpFwXRule(BaseIpFwMessage): + messages = [Op3CmdType.IP_FW_XADD] + attr_map = rule_attrs + + +legacy_classes = [] +set3_classes = [] +get3_classes = [IpFwXRule] diff --git a/tests/atf_python/sys/netpfil/ipfw/ioctl_headers.py b/tests/atf_python/sys/netpfil/ipfw/ioctl_headers.py new file mode 100644 index 000000000000..dc5c74bd1ad1 --- /dev/null +++ b/tests/atf_python/sys/netpfil/ipfw/ioctl_headers.py @@ -0,0 +1,90 @@ +from enum import Enum + + +class Op3CmdType(Enum): + IP_FW_TABLE_XADD = 86 + IP_FW_TABLE_XDEL = 87 + IP_FW_TABLE_XGETSIZE = 88 + IP_FW_TABLE_XLIST = 89 + IP_FW_TABLE_XDESTROY = 90 + IP_FW_TABLES_XLIST = 92 + IP_FW_TABLE_XINFO = 93 + IP_FW_TABLE_XFLUSH = 94 + IP_FW_TABLE_XCREATE = 95 + IP_FW_TABLE_XMODIFY = 96 + IP_FW_XGET = 97 + IP_FW_XADD = 98 + IP_FW_XDEL = 99 + IP_FW_XMOVE = 100 + IP_FW_XZERO = 101 + IP_FW_XRESETLOG = 102 + IP_FW_SET_SWAP = 103 + IP_FW_SET_MOVE = 104 + IP_FW_SET_ENABLE = 105 + IP_FW_TABLE_XFIND = 106 + IP_FW_XIFLIST = 107 + IP_FW_TABLES_ALIST = 108 + IP_FW_TABLE_XSWAP = 109 + IP_FW_TABLE_VLIST = 110 + IP_FW_NAT44_XCONFIG = 111 + IP_FW_NAT44_DESTROY = 112 + IP_FW_NAT44_XGETCONFIG = 113 + IP_FW_NAT44_LIST_NAT = 114 + IP_FW_NAT44_XGETLOG = 115 + IP_FW_DUMP_SOPTCODES = 116 + IP_FW_DUMP_SRVOBJECTS = 117 + IP_FW_NAT64STL_CREATE = 130 + IP_FW_NAT64STL_DESTROY = 131 + IP_FW_NAT64STL_CONFIG = 132 + IP_FW_NAT64STL_LIST = 133 + IP_FW_NAT64STL_STATS = 134 + IP_FW_NAT64STL_RESET_STATS = 135 + IP_FW_NAT64LSN_CREATE = 140 + IP_FW_NAT64LSN_DESTROY = 141 + IP_FW_NAT64LSN_CONFIG = 142 + IP_FW_NAT64LSN_LIST = 143 + IP_FW_NAT64LSN_STATS = 144 + IP_FW_NAT64LSN_LIST_STATES = 145 + IP_FW_NAT64LSN_RESET_STATS = 146 + IP_FW_NPTV6_CREATE = 150 + IP_FW_NPTV6_DESTROY = 151 + IP_FW_NPTV6_CONFIG = 152 + IP_FW_NPTV6_LIST = 153 + IP_FW_NPTV6_STATS = 154 + IP_FW_NPTV6_RESET_STATS = 155 + IP_FW_NAT64CLAT_CREATE = 160 + IP_FW_NAT64CLAT_DESTROY = 161 + IP_FW_NAT64CLAT_CONFIG = 162 + IP_FW_NAT64CLAT_LIST = 163 + IP_FW_NAT64CLAT_STATS = 164 + IP_FW_NAT64CLAT_RESET_STATS = 165 + + +class IpFwTableLookupType(Enum): + LOOKUP_DST_IP = 0 + LOOKUP_SRC_IP = 1 + LOOKUP_DST_PORT = 2 + LOOKUP_SRC_PORT = 3 + LOOKUP_UID = 4 + LOOKUP_JAIL = 5 + LOOKUP_DSCP = 6 + LOOKUP_DST_MAC = 7 + LOOKUP_SRC_MAC = 8 + LOOKUP_MARK = 9 + + +class IpFwTlvType(Enum): + IPFW_TLV_TBL_NAME = 1 + IPFW_TLV_TBLNAME_LIST = 2 + IPFW_TLV_RULE_LIST = 3 + IPFW_TLV_DYNSTATE_LIST = 4 + IPFW_TLV_TBL_ENT = 5 + IPFW_TLV_DYN_ENT = 6 + IPFW_TLV_RULE_ENT = 7 + IPFW_TLV_TBLENT_LIST = 8 + IPFW_TLV_RANGE = 9 + IPFW_TLV_EACTION = 10 + IPFW_TLV_COUNTERS = 11 + IPFW_TLV_OBJDATA = 12 + IPFW_TLV_STATE_NAME = 14 + IPFW_TLV_EACTION_BASE = 1000 diff --git a/tests/atf_python/sys/netpfil/ipfw/ipfw.py b/tests/atf_python/sys/netpfil/ipfw/ipfw.py new file mode 100644 index 000000000000..0bcc907eeab8 --- /dev/null +++ b/tests/atf_python/sys/netpfil/ipfw/ipfw.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +import os +import socket +import struct +import subprocess +import sys +from ctypes import c_byte +from ctypes import c_char +from ctypes import c_int +from ctypes import c_long +from ctypes import c_uint32 +from ctypes import c_uint8 +from ctypes import c_ulong +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Union + +from atf_python.sys.netpfil.ipfw.ioctl import get3_classes +from atf_python.sys.netpfil.ipfw.ioctl import legacy_classes +from atf_python.sys.netpfil.ipfw.ioctl import set3_classes +from atf_python.sys.netpfil.ipfw.utils import AttrDescr +from atf_python.sys.netpfil.ipfw.utils import enum_from_int +from atf_python.sys.netpfil.ipfw.utils import enum_or_int +from atf_python.sys.netpfil.ipfw.utils import prepare_attrs_map + + +class DebugHeader(Structure): + _fields_ = [ + ("cmd_type", c_ushort), + ("spare1", c_ushort), + ("opt_name", c_uint32), + ("total_len", c_uint32), + ("spare2", c_uint32), + ] + + +class DebugType(Enum): + DO_CMD = 1 + DO_SET3 = 2 + DO_GET3 = 3 + + +class DebugIoReader(object): + HANDLER_CLASSES = { + DebugType.DO_CMD: legacy_classes, + DebugType.DO_SET3: set3_classes, + DebugType.DO_GET3: get3_classes, + } + + def __init__(self, ipfw_path): + self._msgmap = self.build_msgmap() + self.ipfw_path = ipfw_path + + def build_msgmap(self): + xmap = {} + for debug_type, handler_classes in self.HANDLER_CLASSES.items(): + debug_type = enum_or_int(debug_type) + if debug_type not in xmap: + xmap[debug_type] = {} + for handler_class in handler_classes: + for msg in handler_class.messages: + xmap[debug_type][enum_or_int(msg)] = handler_class + return xmap + + def print_obj_header(self, hdr): + debug_type = "#{}".format(hdr.cmd_type) + for _type in self.HANDLER_CLASSES.keys(): + if _type.value == hdr.cmd_type: + debug_type = _type.name.lower() + break + print( + "@@ record for {} len={} optname={}".format( + debug_type, hdr.total_len, hdr.opt_name + ) + ) + + def parse_record(self, data): + hdr = DebugHeader.from_buffer_copy(data[: sizeof(DebugHeader)]) + data = data[sizeof(DebugHeader) :] + cls = self._msgmap[hdr.cmd_type].get(hdr.opt_name) + if cls is not None: + return cls.from_bytes(data) + raise ValueError( + "unsupported cmd_type={} opt_name={}".format(hdr.cmd_type, hdr.opt_name) + ) + + def get_record_from_stdin(self): + data = sys.stdin.buffer.peek(sizeof(DebugHeader)) + if len(data) == 0: + return None + + hdr = DebugHeader.from_buffer_copy(data) + data = sys.stdin.buffer.read(hdr.total_len) + return self.parse_record(data) + + def get_records_from_buffer(self, data): + off = 0 + ret = [] + while off + sizeof(DebugHeader) <= len(data): + hdr = DebugHeader.from_buffer_copy(data[off : off + sizeof(DebugHeader)]) + ret.append(self.parse_record(data[off : off + hdr.total_len])) + off += hdr.total_len + return ret + + def run_ipfw(self, cmd: str) -> bytes: + args = [self.ipfw_path, "-xqn"] + cmd.split() + r = subprocess.run(args, capture_output=True) + return r.stdout + + def get_records(self, cmd: str): + return self.get_records_from_buffer(self.run_ipfw(cmd)) diff --git a/tests/atf_python/sys/netpfil/ipfw/utils.py b/tests/atf_python/sys/netpfil/ipfw/utils.py new file mode 100644 index 000000000000..0b3e9570d216 --- /dev/null +++ b/tests/atf_python/sys/netpfil/ipfw/utils.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +import os +import socket +import struct +import subprocess +import sys +from enum import Enum +from typing import Dict +from typing import List +from typing import Optional +from typing import Union +from typing import Any +from typing import NamedTuple +import pytest + + +def roundup2(val: int, num: int) -> int: + if val % num: + return (val | (num - 1)) + 1 + else: + return val + + +def align8(val: int) -> int: + return roundup2(val, 8) + + +def enum_or_int(val) -> int: + if isinstance(val, Enum): + return val.value + return val + + +def enum_from_int(enum_class: Enum, val) -> Enum: + if isinstance(val, Enum): + return val + for item in enum_class: + if val == item.value: + return item + return None + + +class AttrDescr(NamedTuple): + val: Enum + cls: Any + child_map: Any = None + is_array: bool = False + + +def prepare_attrs_map(attrs: List[AttrDescr]) -> Dict[str, Dict]: + ret = {} + for ad in attrs: + ret[ad.val.value] = {"ad": ad} + if ad.child_map: + ret[ad.val.value]["child"] = prepare_attrs_map(ad.child_map) + ret[ad.val.value]["is_array"] = ad.is_array + return ret + + + diff --git a/tests/atf_python/utils.py b/tests/atf_python/utils.py new file mode 100644 index 000000000000..26911c12aef3 --- /dev/null +++ b/tests/atf_python/utils.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +import os +import pwd +from ctypes import CDLL +from ctypes import get_errno +from ctypes.util import find_library +from typing import Dict +from typing import List +from typing import Optional + +import pytest + + +def nodeid_to_method_name(nodeid: str) -> str: + """file_name.py::ClassName::method_name[parametrize] -> method_name""" + return nodeid.split("::")[-1].split("[")[0] + + +class LibCWrapper(object): + def __init__(self): + path: Optional[str] = find_library("c") + if path is None: + raise RuntimeError("libc not found") + self._libc = CDLL(path, use_errno=True) + + def modfind(self, mod_name: str) -> int: + if self._libc.modfind(bytes(mod_name, encoding="ascii")) == -1: + return get_errno() + return 0 + + def kldload(self, kld_name: str) -> int: + if self._libc.kldload(bytes(kld_name, encoding="ascii")) == -1: + return get_errno() + return 0 + + def jail_attach(self, jid: int) -> int: + if self._libc.jail_attach(jid) != 0: + return get_errno() + return 0 + + +libc = LibCWrapper() + + +class BaseTest(object): + NEED_ROOT: bool = False # True if the class needs root privileges for the setup + TARGET_USER = None # Set to the target user by the framework + REQUIRED_MODULES: List[str] = [] + SKIP_MODULES: List[str] = [] + + def require_module(self, mod_name: str, skip=True): + error_code = libc.modfind(mod_name) + if error_code == 0: + return + err_str = os.strerror(error_code) + txt = "kernel module '{}' not available: {}".format(mod_name, err_str) + if skip: + pytest.skip(txt) + else: + raise ValueError(txt) + + def skip_module(self, mod_name: str): + error_code = libc.modfind(mod_name) + if error_code == 0: + txt = "kernel module '{}' loaded, skip test".format(mod_name) + pytest.skip(txt) + return + + def _check_modules(self): + for mod_name in self.REQUIRED_MODULES: + self.require_module(mod_name) + for mod_name in self.SKIP_MODULES: + self.skip_module(mod_name) + + @property + def atf_vars(self) -> Dict[str, str]: + px = "_ATF_VAR_" + return {k[len(px):]: v for k, v in os.environ.items() if k.startswith(px)} + + def drop_privileges_user(self, user: str): + uid = pwd.getpwnam(user)[2] + print("Dropping privs to {}/{}".format(user, uid)) + os.setuid(uid) + + def drop_privileges(self): + if self.TARGET_USER: + if self.TARGET_USER == "unprivileged": + user = self.atf_vars["unprivileged-user"] + else: + user = self.TARGET_USER + self.drop_privileges_user(user) + + @property + def test_id(self) -> str: + # 'test_ip6_output.py::TestIP6Output::test_output6_pktinfo[ipandif] (setup)' + return os.environ.get("PYTEST_CURRENT_TEST").split(" ")[0] + + def setup_method(self, method): + """Run all pre-requisits for the test execution""" + self._check_modules() |