diff options
Diffstat (limited to 'tests/atf_python/sys/netlink')
-rw-r--r-- | tests/atf_python/sys/netlink/Makefile | 13 | ||||
-rw-r--r-- | tests/atf_python/sys/netlink/__init__.py | 0 | ||||
-rw-r--r-- | tests/atf_python/sys/netlink/attrs.py | 335 | ||||
-rw-r--r-- | tests/atf_python/sys/netlink/base_headers.py | 72 | ||||
-rw-r--r-- | tests/atf_python/sys/netlink/message.py | 286 | ||||
-rw-r--r-- | tests/atf_python/sys/netlink/netlink.py | 417 | ||||
-rw-r--r-- | tests/atf_python/sys/netlink/netlink_generic.py | 312 | ||||
-rw-r--r-- | tests/atf_python/sys/netlink/netlink_route.py | 832 | ||||
-rw-r--r-- | tests/atf_python/sys/netlink/utils.py | 80 |
9 files changed, 2347 insertions, 0 deletions
diff --git a/tests/atf_python/sys/netlink/Makefile b/tests/atf_python/sys/netlink/Makefile new file mode 100644 index 000000000000..6a40a93f3ae9 --- /dev/null +++ b/tests/atf_python/sys/netlink/Makefile @@ -0,0 +1,13 @@ +.include <src.opts.mk> + +.PATH: ${.CURDIR} + +PACKAGE=tests +FILES= __init__.py attrs.py base_headers.py message.py netlink.py \ + netlink_generic.py netlink_route.py utils.py + +.include <bsd.own.mk> +FILESDIR= ${TESTSBASE}/atf_python/sys/netlink + +.include <bsd.prog.mk> + diff --git a/tests/atf_python/sys/netlink/__init__.py b/tests/atf_python/sys/netlink/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 --- /dev/null +++ b/tests/atf_python/sys/netlink/__init__.py diff --git a/tests/atf_python/sys/netlink/attrs.py b/tests/atf_python/sys/netlink/attrs.py new file mode 100644 index 000000000000..36dd8191df1c --- /dev/null +++ b/tests/atf_python/sys/netlink/attrs.py @@ -0,0 +1,335 @@ +import socket +import struct +from enum import Enum + +from atf_python.sys.netlink.utils import align4 +from atf_python.sys.netlink.utils import enum_or_int + + +class NlAttr(object): + HDR_LEN = 4 # sizeof(struct nlattr) + + def __init__(self, nla_type, data): + if isinstance(nla_type, Enum): + self._nla_type = nla_type.value + self._enum = nla_type + else: + self._nla_type = nla_type + self._enum = None + self.nla_list = [] + self._data = data + + @property + def nla_type(self): + return self._nla_type & 0x3FFF + + @property + def nla_len(self): + return len(self._data) + 4 + + def add_nla(self, nla): + self.nla_list.append(nla) + + def print_attr(self, prepend=""): + if self._enum is not None: + type_str = self._enum.name + else: + type_str = "nla#{}".format(self.nla_type) + print( + "{}len={} type={}({}){}".format( + prepend, self.nla_len, type_str, self.nla_type, self._print_attr_value() + ) + ) + + @staticmethod + def _validate(data): + if len(data) < 4: + raise ValueError("attribute too short") + nla_len, nla_type = struct.unpack("@HH", data[:4]) + if nla_len > len(data): + raise ValueError("attribute length too big") + if nla_len < 4: + raise ValueError("attribute length too short") + + @classmethod + def _parse(cls, data): + nla_len, nla_type = struct.unpack("@HH", data[:4]) + return cls(nla_type, data[4:]) + + @classmethod + def from_bytes(cls, data, attr_type_enum=None): + cls._validate(data) + attr = cls._parse(data) + attr._enum = attr_type_enum + return attr + + def _to_bytes(self, data: bytes): + ret = data + if align4(len(ret)) != len(ret): + ret = data + bytes(align4(len(ret)) - len(ret)) + return struct.pack("@HH", len(data) + 4, self._nla_type) + ret + + def __bytes__(self): + return self._to_bytes(self._data) + + def _print_attr_value(self): + return " " + " ".join(["x{:02X}".format(b) for b in self._data]) + + +class NlAttrNested(NlAttr): + def __init__(self, nla_type, val): + super().__init__(nla_type, b"") + self.nla_list = val + + def get_nla(self, nla_type): + nla_type_raw = enum_or_int(nla_type) + for nla in self.nla_list: + if nla.nla_type == nla_type_raw: + return nla + return None + + @property + def nla_len(self): + return align4(len(b"".join([bytes(nla) for nla in self.nla_list]))) + 4 + + def print_attr(self, prepend=""): + if self._enum is not None: + type_str = self._enum.name + else: + type_str = "nla#{}".format(self.nla_type) + print( + "{}len={} type={}({}) {{".format( + prepend, self.nla_len, type_str, self.nla_type + ) + ) + for nla in self.nla_list: + nla.print_attr(prepend + " ") + print("{}}}".format(prepend)) + + def __bytes__(self): + return self._to_bytes(b"".join([bytes(nla) for nla in self.nla_list])) + + +class NlAttrU32(NlAttr): + def __init__(self, nla_type, val): + self.u32 = enum_or_int(val) + super().__init__(nla_type, b"") + + @property + def nla_len(self): + return 8 + + def _print_attr_value(self): + return " val={}".format(self.u32) + + @staticmethod + def _validate(data): + assert len(data) == 8 + nla_len, nla_type = struct.unpack("@HH", data[:4]) + assert nla_len == 8 + + @classmethod + def _parse(cls, data): + nla_len, nla_type, val = struct.unpack("@HHI", data) + return cls(nla_type, val) + + def __bytes__(self): + return self._to_bytes(struct.pack("@I", self.u32)) + + +class NlAttrS32(NlAttr): + def __init__(self, nla_type, val): + self.s32 = enum_or_int(val) + super().__init__(nla_type, b"") + + @property + def nla_len(self): + return 8 + + def _print_attr_value(self): + return " val={}".format(self.s32) + + @staticmethod + def _validate(data): + assert len(data) == 8 + nla_len, nla_type = struct.unpack("@HH", data[:4]) + assert nla_len == 8 + + @classmethod + def _parse(cls, data): + nla_len, nla_type, val = struct.unpack("@HHi", data) + return cls(nla_type, val) + + def __bytes__(self): + return self._to_bytes(struct.pack("@i", self.s32)) + + +class NlAttrU16(NlAttr): + def __init__(self, nla_type, val): + self.u16 = enum_or_int(val) + super().__init__(nla_type, b"") + + @property + def nla_len(self): + return 6 + + def _print_attr_value(self): + return " val={}".format(self.u16) + + @staticmethod + def _validate(data): + assert len(data) == 6 + nla_len, nla_type = struct.unpack("@HH", data[:4]) + assert nla_len == 6 + + @classmethod + def _parse(cls, data): + nla_len, nla_type, val = struct.unpack("@HHH", data) + return cls(nla_type, val) + + def __bytes__(self): + return self._to_bytes(struct.pack("@H", self.u16)) + + +class NlAttrU8(NlAttr): + def __init__(self, nla_type, val): + self.u8 = enum_or_int(val) + super().__init__(nla_type, b"") + + @property + def nla_len(self): + return 5 + + def _print_attr_value(self): + return " val={}".format(self.u8) + + @staticmethod + def _validate(data): + assert len(data) == 5 + nla_len, nla_type = struct.unpack("@HH", data[:4]) + assert nla_len == 5 + + @classmethod + def _parse(cls, data): + nla_len, nla_type, val = struct.unpack("@HHB", data) + return cls(nla_type, val) + + def __bytes__(self): + return self._to_bytes(struct.pack("@B", self.u8)) + + +class NlAttrIp(NlAttr): + def __init__(self, nla_type, addr: str): + super().__init__(nla_type, b"") + self.addr = addr + if ":" in self.addr: + self.family = socket.AF_INET6 + else: + self.family = socket.AF_INET + + @staticmethod + def _validate(data): + nla_len, nla_type = struct.unpack("@HH", data[:4]) + data_len = nla_len - 4 + if data_len != 4 and data_len != 16: + raise ValueError( + "Error validating attr {}: nla_len is not valid".format( # noqa: E501 + nla_type + ) + ) + + @property + def nla_len(self): + if self.family == socket.AF_INET6: + return 20 + else: + return 8 + return align4(len(self._data)) + 4 + + @classmethod + def _parse(cls, data): + nla_len, nla_type = struct.unpack("@HH", data[:4]) + data_len = len(data) - 4 + if data_len == 4: + addr = socket.inet_ntop(socket.AF_INET, data[4:8]) + else: + addr = socket.inet_ntop(socket.AF_INET6, data[4:20]) + return cls(nla_type, addr) + + def __bytes__(self): + return self._to_bytes(socket.inet_pton(self.family, self.addr)) + + def _print_attr_value(self): + return " addr={}".format(self.addr) + + +class NlAttrIp4(NlAttrIp): + def __init__(self, nla_type, addr: str): + super().__init__(nla_type, addr) + assert self.family == socket.AF_INET + + +class NlAttrIp6(NlAttrIp): + def __init__(self, nla_type, addr: str): + super().__init__(nla_type, addr) + assert self.family == socket.AF_INET6 + + +class NlAttrStr(NlAttr): + def __init__(self, nla_type, text): + super().__init__(nla_type, b"") + self.text = text + + @staticmethod + def _validate(data): + NlAttr._validate(data) + try: + data[4:].decode("utf-8") + except Exception as e: + raise ValueError("wrong utf-8 string: {}".format(e)) + + @property + def nla_len(self): + return len(self.text) + 5 + + @classmethod + def _parse(cls, data): + text = data[4:-1].decode("utf-8") + nla_len, nla_type = struct.unpack("@HH", data[:4]) + return cls(nla_type, text) + + def __bytes__(self): + return self._to_bytes(bytes(self.text, encoding="utf-8") + bytes(1)) + + def _print_attr_value(self): + return ' val="{}"'.format(self.text) + + +class NlAttrStrn(NlAttr): + def __init__(self, nla_type, text): + super().__init__(nla_type, b"") + self.text = text + + @staticmethod + def _validate(data): + NlAttr._validate(data) + try: + data[4:].decode("utf-8") + except Exception as e: + raise ValueError("wrong utf-8 string: {}".format(e)) + + @property + def nla_len(self): + return len(self.text) + 4 + + @classmethod + def _parse(cls, data): + text = data[4:].decode("utf-8") + nla_len, nla_type = struct.unpack("@HH", data[:4]) + return cls(nla_type, text) + + def __bytes__(self): + return self._to_bytes(bytes(self.text, encoding="utf-8")) + + def _print_attr_value(self): + return ' val="{}"'.format(self.text) diff --git a/tests/atf_python/sys/netlink/base_headers.py b/tests/atf_python/sys/netlink/base_headers.py new file mode 100644 index 000000000000..71771a249b3d --- /dev/null +++ b/tests/atf_python/sys/netlink/base_headers.py @@ -0,0 +1,72 @@ +from ctypes import c_ubyte +from ctypes import c_uint +from ctypes import c_ushort +from ctypes import Structure +from enum import Enum + + +class Nlmsghdr(Structure): + _fields_ = [ + ("nlmsg_len", c_uint), + ("nlmsg_type", c_ushort), + ("nlmsg_flags", c_ushort), + ("nlmsg_seq", c_uint), + ("nlmsg_pid", c_uint), + ] + + +class Nlattr(Structure): + _fields_ = [ + ("nla_len", c_ushort), + ("nla_type", c_ushort), + ] + + +class NlMsgType(Enum): + NLMSG_NOOP = 1 + NLMSG_ERROR = 2 + NLMSG_DONE = 3 + NLMSG_OVERRUN = 4 + + +class NlmBaseFlags(Enum): + NLM_F_REQUEST = 0x01 + NLM_F_MULTI = 0x02 + NLM_F_ACK = 0x04 + NLM_F_ECHO = 0x08 + NLM_F_DUMP_INTR = 0x10 + NLM_F_DUMP_FILTERED = 0x20 + + +# XXX: in python3.8 it is possible to +# class NlmGetFlags(Enum, NlmBaseFlags): + + +class NlmGetFlags(Enum): + NLM_F_ROOT = 0x100 + NLM_F_MATCH = 0x200 + NLM_F_ATOMIC = 0x400 + + +class NlmNewFlags(Enum): + NLM_F_REPLACE = 0x100 + NLM_F_EXCL = 0x200 + NLM_F_CREATE = 0x400 + NLM_F_APPEND = 0x800 + + +class NlmDeleteFlags(Enum): + NLM_F_NONREC = 0x100 + + +class NlmAckFlags(Enum): + NLM_F_CAPPED = 0x100 + NLM_F_ACK_TLVS = 0x200 + + +class GenlMsgHdr(Structure): + _fields_ = [ + ("cmd", c_ubyte), + ("version", c_ubyte), + ("reserved", c_ushort), + ] diff --git a/tests/atf_python/sys/netlink/message.py b/tests/atf_python/sys/netlink/message.py new file mode 100644 index 000000000000..98a1e3bb21c5 --- /dev/null +++ b/tests/atf_python/sys/netlink/message.py @@ -0,0 +1,286 @@ +#!/usr/local/bin/python3 +import struct +from ctypes import sizeof +from enum import Enum +from typing import List +from typing import NamedTuple + +from atf_python.sys.netlink.attrs import NlAttr +from atf_python.sys.netlink.attrs import NlAttrNested +from atf_python.sys.netlink.base_headers import NlmAckFlags +from atf_python.sys.netlink.base_headers import NlmNewFlags +from atf_python.sys.netlink.base_headers import NlmGetFlags +from atf_python.sys.netlink.base_headers import NlmDeleteFlags +from atf_python.sys.netlink.base_headers import NlmBaseFlags +from atf_python.sys.netlink.base_headers import Nlmsghdr +from atf_python.sys.netlink.base_headers import NlMsgType +from atf_python.sys.netlink.utils import align4 +from atf_python.sys.netlink.utils import enum_or_int +from atf_python.sys.netlink.utils import get_bitmask_str + + +class NlMsgCategory(Enum): + UNKNOWN = 0 + GET = 1 + NEW = 2 + DELETE = 3 + ACK = 4 + + +class NlMsgProps(NamedTuple): + msg: Enum + category: NlMsgCategory + + +class BaseNetlinkMessage(object): + def __init__(self, helper, nlmsg_type): + self.nlmsg_type = enum_or_int(nlmsg_type) + self.nla_list = [] + self._orig_data = None + self.helper = helper + self.nl_hdr = Nlmsghdr( + nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid + ) + self.base_hdr = None + + def set_request(self, need_ack=True): + self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST]) + if need_ack: + self.add_nlflags([NlmBaseFlags.NLM_F_ACK]) + + def add_nlflags(self, flags: List): + int_flags = 0 + for flag in flags: + int_flags |= enum_or_int(flag) + self.nl_hdr.nlmsg_flags |= int_flags + + def add_nla(self, nla): + self.nla_list.append(nla) + + def _get_nla(self, nla_list, nla_type): + nla_type_raw = enum_or_int(nla_type) + for nla in nla_list: + if nla.nla_type == nla_type_raw: + return nla + return None + + def get_nla(self, nla_type): + return self._get_nla(self.nla_list, nla_type) + + @staticmethod + def parse_nl_header(data: bytes): + if len(data) < sizeof(Nlmsghdr): + raise ValueError("length less than netlink message header") + return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr) + + def is_type(self, nlmsg_type): + nlmsg_type_raw = enum_or_int(nlmsg_type) + return nlmsg_type_raw == self.nl_hdr.nlmsg_type + + def is_reply(self, hdr): + return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value + + @property + def msg_name(self): + return "msg#{}".format(self._get_msg_type()) + + def _get_nl_category(self): + if self.is_reply(self.nl_hdr): + return NlMsgCategory.ACK + return NlMsgCategory.UNKNOWN + + def get_nlm_flags_str(self): + category = self._get_nl_category() + flags = self.nl_hdr.nlmsg_flags + + if category == NlMsgCategory.UNKNOWN: + return self.helper.get_bitmask_str(NlmBaseFlags, flags) + elif category == NlMsgCategory.GET: + flags_enum = NlmGetFlags + elif category == NlMsgCategory.NEW: + flags_enum = NlmNewFlags + elif category == NlMsgCategory.DELETE: + flags_enum = NlmDeleteFlags + elif category == NlMsgCategory.ACK: + flags_enum = NlmAckFlags + return get_bitmask_str([NlmBaseFlags, flags_enum], flags) + + def print_nl_header(self, prepend=""): + # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501 + hdr = self.nl_hdr + print( + "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( + prepend, + hdr.nlmsg_len, + self.msg_name, + self.get_nlm_flags_str(), + hdr.nlmsg_flags, + hdr.nlmsg_seq, + hdr.nlmsg_pid, + ) + ) + + @classmethod + def from_bytes(cls, helper, data): + try: + hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data) + self = cls(helper, hdr.nlmsg_type) + self._orig_data = data + self.nl_hdr = hdr + except ValueError as e: + print("Failed to parse nl header: {}".format(e)) + cls.print_as_bytes(data) + raise + return self + + def print_message(self): + self.print_nl_header() + + @staticmethod + def print_as_bytes(data: bytes, descr: str): + print("===vv {} (len:{:3d}) vv===".format(descr, len(data))) + off = 0 + step = 16 + while off < len(data): + for i in range(step): + if off + i < len(data): + print(" {:02X}".format(data[off + i]), end="") + print("") + off += step + print("--------------------") + + +class StdNetlinkMessage(BaseNetlinkMessage): + nl_attrs_map = {} + + @classmethod + def from_bytes(cls, helper, data): + try: + hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data) + self = cls(helper, hdr.nlmsg_type) + self._orig_data = data + self.nl_hdr = hdr + except ValueError as e: + print("Failed to parse nl header: {}".format(e)) + cls.print_as_bytes(data) + raise + + offset = align4(hdrlen) + try: + base_hdr, hdrlen = self.parse_base_header(data[offset:]) + self.base_hdr = base_hdr + offset += align4(hdrlen) + # XXX: CAP_ACK + except ValueError as e: + print("Failed to parse nl rt header: {}".format(e)) + cls.print_as_bytes(data) + raise + + orig_offset = offset + try: + nla_list, nla_len = self.parse_nla_list(data[offset:]) + offset += nla_len + if offset != len(data): + raise ValueError( + "{} bytes left at the end of the packet".format(len(data) - offset) + ) # noqa: E501 + self.nla_list = nla_list + except ValueError as e: + print( + "Failed to parse nla attributes at offset {}: {}".format(orig_offset, e) + ) # noqa: E501 + cls.print_as_bytes(data, "msg dump") + cls.print_as_bytes(data[orig_offset:], "failed block") + raise + return self + + def parse_child(self, data: bytes, attr_key, attr_map): + attrs, _ = self.parse_attrs(data, attr_map) + return NlAttrNested(attr_key, attrs) + + def parse_child_array(self, data: bytes, attr_key, attr_map): + ret = [] + off = 0 + while len(data) - off >= 4: + nla_len, raw_nla_type = struct.unpack("@HH", data[off : off + 4]) + if nla_len + off > len(data): + raise ValueError( + "attr length {} > than the remaining length {}".format( + nla_len, len(data) - off + ) + ) + nla_type = raw_nla_type & 0x3FFF + val = self.parse_child(data[off + 4 : off + nla_len], nla_type, attr_map) + ret.append(val) + off += align4(nla_len) + return NlAttrNested(attr_key, ret) + + def parse_attrs(self, data: bytes, attr_map): + ret = [] + off = 0 + while len(data) - off >= 4: + nla_len, raw_nla_type = struct.unpack("@HH", data[off : off + 4]) + if nla_len + off > len(data): + raise ValueError( + "attr length {} > than the remaining length {}".format( + nla_len, len(data) - off + ) + ) + nla_type = raw_nla_type & 0x3FFF + if nla_type in attr_map: + v = attr_map[nla_type] + val = v["ad"].cls.from_bytes(data[off : off + nla_len], v["ad"].val) + if "child" in v: + # nested + child_data = data[off + 4 : off + nla_len] + if v.get("is_array", False): + # Array of nested attributes + val = self.parse_child_array( + child_data, v["ad"].val, v["child"] + ) + else: + val = self.parse_child(child_data, v["ad"].val, v["child"]) + else: + # unknown attribute + val = NlAttr(raw_nla_type, data[off + 4 : off + nla_len]) + ret.append(val) + off += align4(nla_len) + return ret, off + + def parse_nla_list(self, data: bytes) -> List[NlAttr]: + return self.parse_attrs(data, self.nl_attrs_map) + + def __bytes__(self): + ret = bytes() + for nla in self.nla_list: + ret += bytes(nla) + ret = bytes(self.base_hdr) + ret + self.nl_hdr.nlmsg_len = len(ret) + sizeof(Nlmsghdr) + return bytes(self.nl_hdr) + ret + + def _get_msg_type(self): + return self.nl_hdr.nlmsg_type + + @property + def msg_props(self): + msg_type = self._get_msg_type() + for msg_props in self.messages: + if msg_props.msg.value == msg_type: + return msg_props + return None + + @property + def msg_name(self): + msg_props = self.msg_props + if msg_props is not None: + return msg_props.msg.name + return super().msg_name + + def print_base_header(self, hdr, prepend=""): + pass + + def print_message(self): + self.print_nl_header() + self.print_base_header(self.base_hdr, " ") + for nla in self.nla_list: + nla.print_attr(" ") diff --git a/tests/atf_python/sys/netlink/netlink.py b/tests/atf_python/sys/netlink/netlink.py new file mode 100644 index 000000000000..f8f886b09b24 --- /dev/null +++ b/tests/atf_python/sys/netlink/netlink.py @@ -0,0 +1,417 @@ +#!/usr/local/bin/python3 +import os +import socket +import sys +from ctypes import c_int +from ctypes import c_ubyte +from ctypes import c_uint +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from enum import auto +from enum import Enum + +from atf_python.sys.netlink.attrs import NlAttr +from atf_python.sys.netlink.attrs import NlAttrStr +from atf_python.sys.netlink.attrs import NlAttrU32 +from atf_python.sys.netlink.base_headers import GenlMsgHdr +from atf_python.sys.netlink.base_headers import NlmBaseFlags +from atf_python.sys.netlink.base_headers import Nlmsghdr +from atf_python.sys.netlink.base_headers import NlMsgType +from atf_python.sys.netlink.message import BaseNetlinkMessage +from atf_python.sys.netlink.message import NlMsgCategory +from atf_python.sys.netlink.message import NlMsgProps +from atf_python.sys.netlink.message import StdNetlinkMessage +from atf_python.sys.netlink.netlink_generic import GenlCtrlAttrType +from atf_python.sys.netlink.netlink_generic import GenlCtrlMsgType +from atf_python.sys.netlink.netlink_generic import handler_classes as genl_classes +from atf_python.sys.netlink.netlink_route import handler_classes as rt_classes +from atf_python.sys.netlink.utils import align4 +from atf_python.sys.netlink.utils import AttrDescr +from atf_python.sys.netlink.utils import build_propmap +from atf_python.sys.netlink.utils import enum_or_int +from atf_python.sys.netlink.utils import get_bitmask_map +from atf_python.sys.netlink.utils import NlConst +from atf_python.sys.netlink.utils import prepare_attrs_map + + +class SockaddrNl(Structure): + _fields_ = [ + ("nl_len", c_ubyte), + ("nl_family", c_ubyte), + ("nl_pad", c_ushort), + ("nl_pid", c_uint), + ("nl_groups", c_uint), + ] + + +class Nlmsgdone(Structure): + _fields_ = [ + ("error", c_int), + ] + + +class Nlmsgerr(Structure): + _fields_ = [ + ("error", c_int), + ("msg", Nlmsghdr), + ] + + +class NlErrattrType(Enum): + NLMSGERR_ATTR_UNUSED = 0 + NLMSGERR_ATTR_MSG = auto() + NLMSGERR_ATTR_OFFS = auto() + NLMSGERR_ATTR_COOKIE = auto() + NLMSGERR_ATTR_POLICY = auto() + + +class AddressFamilyLinux(Enum): + AF_INET = socket.AF_INET + AF_INET6 = socket.AF_INET6 + AF_NETLINK = 16 + + +class AddressFamilyBsd(Enum): + AF_INET = socket.AF_INET + AF_INET6 = socket.AF_INET6 + AF_NETLINK = 38 + + +class NlHelper: + def __init__(self): + self._pmap = {} + self._af_cls = self.get_af_cls() + self._seq_counter = 1 + self.pid = os.getpid() + + def get_seq(self): + ret = self._seq_counter + self._seq_counter += 1 + return ret + + def get_af_cls(self): + if sys.platform.startswith("freebsd"): + cls = AddressFamilyBsd + else: + cls = AddressFamilyLinux + return cls + + def get_propmap(self, cls): + if cls not in self._pmap: + self._pmap[cls] = build_propmap(cls) + return self._pmap[cls] + + def get_name_propmap(self, cls): + ret = {} + for prop in dir(cls): + if not prop.startswith("_"): + ret[prop] = getattr(cls, prop).value + return ret + + def get_attr_byval(self, cls, attr_val): + propmap = self.get_propmap(cls) + return propmap.get(attr_val) + + def get_af_name(self, family): + v = self.get_attr_byval(self._af_cls, family) + if v is not None: + return v + return "af#{}".format(family) + + def get_af_value(self, family_str: str) -> int: + propmap = self.get_name_propmap(self._af_cls) + return propmap.get(family_str) + + def get_bitmask_str(self, cls, val): + bmap = get_bitmask_map(self.get_propmap(cls), val) + return ",".join([v for k, v in bmap.items()]) + + @staticmethod + def get_bitmask_str_uncached(cls, val): + pmap = NlHelper.build_propmap(cls) + bmap = NlHelper.get_bitmask_map(pmap, val) + return ",".join([v for k, v in bmap.items()]) + + +nldone_attrs = prepare_attrs_map([]) + +nlerr_attrs = prepare_attrs_map( + [ + AttrDescr(NlErrattrType.NLMSGERR_ATTR_MSG, NlAttrStr), + AttrDescr(NlErrattrType.NLMSGERR_ATTR_OFFS, NlAttrU32), + AttrDescr(NlErrattrType.NLMSGERR_ATTR_COOKIE, NlAttr), + ] +) + + +class NetlinkDoneMessage(StdNetlinkMessage): + messages = [NlMsgProps(NlMsgType.NLMSG_DONE, NlMsgCategory.ACK)] + nl_attrs_map = nldone_attrs + + @property + def error_code(self): + return self.base_hdr.error + + def parse_base_header(self, data): + if len(data) < sizeof(Nlmsgdone): + raise ValueError("length less than nlmsgdone header") + done_hdr = Nlmsgdone.from_buffer_copy(data) + sz = sizeof(Nlmsgdone) + return (done_hdr, sz) + + def print_base_header(self, hdr, prepend=""): + print("{}error={}".format(prepend, hdr.error)) + + +class NetlinkErrorMessage(StdNetlinkMessage): + messages = [NlMsgProps(NlMsgType.NLMSG_ERROR, NlMsgCategory.ACK)] + nl_attrs_map = nlerr_attrs + + @property + def error_code(self): + return self.base_hdr.error + + @property + def error_str(self): + nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_MSG) + if nla: + return nla.text + return None + + @property + def error_offset(self): + nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_OFFS) + if nla: + return nla.u32 + return None + + @property + def cookie(self): + return self.get_nla(NlErrattrType.NLMSGERR_ATTR_COOKIE) + + def parse_base_header(self, data): + if len(data) < sizeof(Nlmsgerr): + raise ValueError("length less than nlmsgerr header") + err_hdr = Nlmsgerr.from_buffer_copy(data) + sz = sizeof(Nlmsgerr) + if (self.nl_hdr.nlmsg_flags & 0x100) == 0: + sz += align4(err_hdr.msg.nlmsg_len - sizeof(Nlmsghdr)) + return (err_hdr, sz) + + def print_base_header(self, errhdr, prepend=""): + print("{}error={}, ".format(prepend, errhdr.error), end="") + hdr = errhdr.msg + print( + "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( + prepend, + hdr.nlmsg_len, + "msg#{}".format(hdr.nlmsg_type), + self.helper.get_bitmask_str(NlmBaseFlags, hdr.nlmsg_flags), + hdr.nlmsg_flags, + hdr.nlmsg_seq, + hdr.nlmsg_pid, + ) + ) + + +core_classes = { + "netlink_core": [ + NetlinkDoneMessage, + NetlinkErrorMessage, + ], +} + + +class Nlsock: + HANDLER_CLASSES = [core_classes, rt_classes, genl_classes] + + def __init__(self, family, helper): + self.helper = helper + self.sock_fd = self._setup_netlink(family) + self._sock_family = family + self._data = bytes() + self.msgmap = self.build_msgmap() + self._family_map = { + NlConst.GENL_ID_CTRL: "nlctrl", + } + + def build_msgmap(self): + handler_classes = {} + for d in self.HANDLER_CLASSES: + handler_classes.update(d) + xmap = {} + # 'family_name': [class.messages[MsgProps.msg], ] + for family_id, family_classes in handler_classes.items(): + xmap[family_id] = {} + for cls in family_classes: + for msg_props in cls.messages: + xmap[family_id][enum_or_int(msg_props.msg)] = cls + return xmap + + def _setup_netlink(self, netlink_family) -> int: + family = self.helper.get_af_value("AF_NETLINK") + s = socket.socket(family, socket.SOCK_RAW, netlink_family) + s.setsockopt(270, 10, 1) # NETLINK_CAP_ACK + s.setsockopt(270, 11, 1) # NETLINK_EXT_ACK + return s + + def set_groups(self, mask: int): + self.sock_fd.setsockopt(socket.SOL_SOCKET, 1, mask) + # snl = SockaddrNl(nl_len = sizeof(SockaddrNl), nl_family=38, + # nl_pid=self.pid, nl_groups=mask) + # xbuffer = create_string_buffer(sizeof(SockaddrNl)) + # memmove(xbuffer, addressof(snl), sizeof(SockaddrNl)) + # k = struct.pack("@BBHII", 12, 38, 0, self.pid, mask) + # self.sock_fd.bind(k) + + def join_group(self, group_id: int): + self.sock_fd.setsockopt(270, 1, group_id) + + def write_message(self, msg, verbose=True): + if verbose: + print("vvvvvvvv OUT vvvvvvvv") + msg.print_message() + msg_bytes = bytes(msg) + try: + ret = os.write(self.sock_fd.fileno(), msg_bytes) + assert ret == len(msg_bytes) + except Exception as e: + print("write({}) -> {}".format(len(msg_bytes), e)) + + def parse_message(self, data: bytes): + if len(data) < sizeof(Nlmsghdr): + raise Exception("Short read from nl: {} bytes".format(len(data))) + hdr = Nlmsghdr.from_buffer_copy(data) + if hdr.nlmsg_type < 16: + family_name = "netlink_core" + nlmsg_type = hdr.nlmsg_type + elif self._sock_family == NlConst.NETLINK_ROUTE: + family_name = "netlink_route" + nlmsg_type = hdr.nlmsg_type + else: + # Genetlink + if len(data) < sizeof(Nlmsghdr) + sizeof(GenlMsgHdr): + raise Exception("Short read from genl: {} bytes".format(len(data))) + family_name = self._family_map.get(hdr.nlmsg_type, "") + ghdr = GenlMsgHdr.from_buffer_copy(data[sizeof(Nlmsghdr):]) + nlmsg_type = ghdr.cmd + cls = self.msgmap.get(family_name, {}).get(nlmsg_type) + if not cls: + cls = BaseNetlinkMessage + return cls.from_bytes(self.helper, data) + + def get_genl_family_id(self, family_name): + hdr = Nlmsghdr( + nlmsg_type=NlConst.GENL_ID_CTRL, + nlmsg_flags=NlmBaseFlags.NLM_F_REQUEST.value, + nlmsg_seq=self.helper.get_seq(), + ) + ghdr = GenlMsgHdr(cmd=GenlCtrlMsgType.CTRL_CMD_GETFAMILY.value) + nla = NlAttrStr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, family_name) + hdr.nlmsg_len = sizeof(Nlmsghdr) + sizeof(GenlMsgHdr) + len(bytes(nla)) + + msg_bytes = bytes(hdr) + bytes(ghdr) + bytes(nla) + self.write_data(msg_bytes) + while True: + rx_msg = self.read_message() + if hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: + if rx_msg.is_type(NlMsgType.NLMSG_ERROR): + if rx_msg.error_code != 0: + raise ValueError("unable to get family {}".format(family_name)) + else: + family_id = rx_msg.get_nla(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID).u16 + self._family_map[family_id] = family_name + return family_id + raise ValueError("unable to get family {}".format(family_name)) + + def write_data(self, data: bytes): + self.sock_fd.send(data) + + def read_data(self): + while True: + data = self.sock_fd.recv(65535) + self._data += data + if len(self._data) >= sizeof(Nlmsghdr): + break + + def read_message(self) -> bytes: + if len(self._data) < sizeof(Nlmsghdr): + self.read_data() + hdr = Nlmsghdr.from_buffer_copy(self._data) + while hdr.nlmsg_len > len(self._data): + self.read_data() + raw_msg = self._data[: hdr.nlmsg_len] + self._data = self._data[hdr.nlmsg_len:] + return self.parse_message(raw_msg) + + def get_reply(self, tx_msg): + self.write_message(tx_msg) + while True: + rx_msg = self.read_message() + if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: + return rx_msg + + +class NetlinkMultipartIterator(object): + def __init__(self, obj, seq_number: int, msg_type): + self._obj = obj + self._seq = seq_number + self._msg_type = msg_type + + def __iter__(self): + return self + + def __next__(self): + msg = self._obj.read_message() + if self._seq != msg.nl_hdr.nlmsg_seq: + raise ValueError("bad sequence number") + if msg.is_type(NlMsgType.NLMSG_ERROR): + raise ValueError( + "error while handling multipart msg: {}".format(msg.error_code) + ) + elif msg.is_type(NlMsgType.NLMSG_DONE): + if msg.error_code == 0: + raise StopIteration + raise ValueError( + "error listing some parts of the multipart msg: {}".format( + msg.error_code + ) + ) + elif not msg.is_type(self._msg_type): + raise ValueError("bad message type: {}".format(msg)) + return msg + + +class NetlinkTestTemplate(object): + REQUIRED_MODULES = ["netlink"] + + def setup_netlink(self, netlink_family: NlConst): + self.helper = NlHelper() + self.nlsock = Nlsock(netlink_family, self.helper) + + def write_message(self, msg, silent=False): + if not silent: + print("") + print("============= >> TX MESSAGE =============") + msg.print_message() + msg.print_as_bytes(bytes(msg), "-- DATA --") + self.nlsock.write_data(bytes(msg)) + + def read_message(self, silent=False): + msg = self.nlsock.read_message() + if not silent: + print("") + print("============= << RX MESSAGE =============") + msg.print_message() + return msg + + def get_reply(self, tx_msg): + self.write_message(tx_msg) + while True: + rx_msg = self.read_message() + if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: + return rx_msg + + def read_msg_list(self, seq, msg_type): + return list(NetlinkMultipartIterator(self, seq, msg_type)) diff --git a/tests/atf_python/sys/netlink/netlink_generic.py b/tests/atf_python/sys/netlink/netlink_generic.py new file mode 100644 index 000000000000..80c6eea72a93 --- /dev/null +++ b/tests/atf_python/sys/netlink/netlink_generic.py @@ -0,0 +1,312 @@ +#!/usr/local/bin/python3 +import struct +from ctypes import c_int64 +from ctypes import c_long +from ctypes import sizeof +from ctypes import Structure +from enum import Enum + +from atf_python.sys.netlink.attrs import NlAttr +from atf_python.sys.netlink.attrs import NlAttrIp4 +from atf_python.sys.netlink.attrs import NlAttrIp6 +from atf_python.sys.netlink.attrs import NlAttrNested +from atf_python.sys.netlink.attrs import NlAttrS32 +from atf_python.sys.netlink.attrs import NlAttrStr +from atf_python.sys.netlink.attrs import NlAttrU16 +from atf_python.sys.netlink.attrs import NlAttrU32 +from atf_python.sys.netlink.attrs import NlAttrU8 +from atf_python.sys.netlink.base_headers import GenlMsgHdr +from atf_python.sys.netlink.message import NlMsgCategory +from atf_python.sys.netlink.message import NlMsgProps +from atf_python.sys.netlink.message import StdNetlinkMessage +from atf_python.sys.netlink.utils import AttrDescr +from atf_python.sys.netlink.utils import enum_or_int +from atf_python.sys.netlink.utils import prepare_attrs_map + + +class NetlinkGenlMessage(StdNetlinkMessage): + messages = [] + nl_attrs_map = {} + family_name = None + + def __init__(self, helper, family_id, cmd=0): + super().__init__(helper, family_id) + self.base_hdr = GenlMsgHdr(cmd=enum_or_int(cmd)) + + def parse_base_header(self, data): + if len(data) < sizeof(GenlMsgHdr): + raise ValueError("length less than GenlMsgHdr header") + ghdr = GenlMsgHdr.from_buffer_copy(data) + return (ghdr, sizeof(GenlMsgHdr)) + + def _get_msg_type(self): + return self.base_hdr.cmd + + def print_nl_header(self, prepend=""): + # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501 + hdr = self.nl_hdr + print( + "{}len={}, family={}, flags={}(0x{:X}), seq={}, pid={}".format( + prepend, + hdr.nlmsg_len, + self.family_name, + self.get_nlm_flags_str(), + hdr.nlmsg_flags, + hdr.nlmsg_seq, + hdr.nlmsg_pid, + ) + ) + + def print_base_header(self, hdr, prepend=""): + print( + "{}cmd={} version={} reserved={}".format( + prepend, self.msg_name, hdr.version, hdr.reserved + ) + ) + + +GenlCtrlFamilyName = "nlctrl" + + +class GenlCtrlMsgType(Enum): + CTRL_CMD_UNSPEC = 0 + CTRL_CMD_NEWFAMILY = 1 + CTRL_CMD_DELFAMILY = 2 + CTRL_CMD_GETFAMILY = 3 + CTRL_CMD_NEWOPS = 4 + CTRL_CMD_DELOPS = 5 + CTRL_CMD_GETOPS = 6 + CTRL_CMD_NEWMCAST_GRP = 7 + CTRL_CMD_DELMCAST_GRP = 8 + CTRL_CMD_GETMCAST_GRP = 9 + CTRL_CMD_GETPOLICY = 10 + + +class GenlCtrlAttrType(Enum): + CTRL_ATTR_FAMILY_ID = 1 + CTRL_ATTR_FAMILY_NAME = 2 + CTRL_ATTR_VERSION = 3 + CTRL_ATTR_HDRSIZE = 4 + CTRL_ATTR_MAXATTR = 5 + CTRL_ATTR_OPS = 6 + CTRL_ATTR_MCAST_GROUPS = 7 + CTRL_ATTR_POLICY = 8 + CTRL_ATTR_OP_POLICY = 9 + CTRL_ATTR_OP = 10 + + +class GenlCtrlAttrOpType(Enum): + CTRL_ATTR_OP_ID = 1 + CTRL_ATTR_OP_FLAGS = 2 + + +class GenlCtrlAttrMcastGroupsType(Enum): + CTRL_ATTR_MCAST_GRP_NAME = 1 + CTRL_ATTR_MCAST_GRP_ID = 2 + + +genl_ctrl_attrs = prepare_attrs_map( + [ + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID, NlAttrU16), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, NlAttrStr), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_VERSION, NlAttrU32), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_HDRSIZE, NlAttrU32), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_MAXATTR, NlAttrU32), + AttrDescr( + GenlCtrlAttrType.CTRL_ATTR_OPS, + NlAttrNested, + [ + AttrDescr(GenlCtrlAttrOpType.CTRL_ATTR_OP_ID, NlAttrU32), + AttrDescr(GenlCtrlAttrOpType.CTRL_ATTR_OP_FLAGS, NlAttrU32), + ], + True, + ), + AttrDescr( + GenlCtrlAttrType.CTRL_ATTR_MCAST_GROUPS, + NlAttrNested, + [ + AttrDescr( + GenlCtrlAttrMcastGroupsType.CTRL_ATTR_MCAST_GRP_NAME, NlAttrStr + ), + AttrDescr( + GenlCtrlAttrMcastGroupsType.CTRL_ATTR_MCAST_GRP_ID, NlAttrU32 + ), + ], + True, + ), + ] +) + + +class NetlinkGenlCtrlMessage(NetlinkGenlMessage): + messages = [ + NlMsgProps(GenlCtrlMsgType.CTRL_CMD_NEWFAMILY, NlMsgCategory.NEW), + NlMsgProps(GenlCtrlMsgType.CTRL_CMD_GETFAMILY, NlMsgCategory.GET), + NlMsgProps(GenlCtrlMsgType.CTRL_CMD_DELFAMILY, NlMsgCategory.DELETE), + ] + nl_attrs_map = genl_ctrl_attrs + family_name = GenlCtrlFamilyName + + +CarpFamilyName = "carp" + + +class CarpMsgType(Enum): + CARP_NL_CMD_UNSPEC = 0 + CARP_NL_CMD_GET = 1 + CARP_NL_CMD_SET = 2 + + +class CarpAttrType(Enum): + CARP_NL_UNSPEC = 0 + CARP_NL_VHID = 1 + CARP_NL_STATE = 2 + CARP_NL_ADVBASE = 3 + CARP_NL_ADVSKEW = 4 + CARP_NL_KEY = 5 + CARP_NL_IFINDEX = 6 + CARP_NL_ADDR = 7 + CARP_NL_ADDR6 = 8 + CARP_NL_IFNAME = 9 + + +carp_gen_attrs = prepare_attrs_map( + [ + AttrDescr(CarpAttrType.CARP_NL_VHID, NlAttrU32), + AttrDescr(CarpAttrType.CARP_NL_STATE, NlAttrU32), + AttrDescr(CarpAttrType.CARP_NL_ADVBASE, NlAttrS32), + AttrDescr(CarpAttrType.CARP_NL_ADVSKEW, NlAttrS32), + AttrDescr(CarpAttrType.CARP_NL_KEY, NlAttr), + AttrDescr(CarpAttrType.CARP_NL_IFINDEX, NlAttrU32), + AttrDescr(CarpAttrType.CARP_NL_ADDR, NlAttrIp4), + AttrDescr(CarpAttrType.CARP_NL_ADDR6, NlAttrIp6), + AttrDescr(CarpAttrType.CARP_NL_IFNAME, NlAttrStr), + ] +) + + +class CarpGenMessage(NetlinkGenlMessage): + messages = [ + NlMsgProps(CarpMsgType.CARP_NL_CMD_GET, NlMsgCategory.GET), + NlMsgProps(CarpMsgType.CARP_NL_CMD_SET, NlMsgCategory.NEW), + ] + nl_attrs_map = carp_gen_attrs + family_name = CarpFamilyName + + +KtestFamilyName = "ktest" + + +class KtestMsgType(Enum): + KTEST_CMD_UNSPEC = 0 + KTEST_CMD_LIST = 1 + KTEST_CMD_RUN = 2 + KTEST_CMD_NEWTEST = 3 + KTEST_CMD_NEWMESSAGE = 4 + + +class KtestAttrType(Enum): + KTEST_ATTR_MOD_NAME = 1 + KTEST_ATTR_TEST_NAME = 2 + KTEST_ATTR_TEST_DESCR = 3 + KTEST_ATTR_TEST_META = 4 + + +class KtestLogMsgType(Enum): + KTEST_MSG_START = 1 + KTEST_MSG_END = 2 + KTEST_MSG_LOG = 3 + KTEST_MSG_FAIL = 4 + + +class KtestMsgAttrType(Enum): + KTEST_MSG_ATTR_TS = 1 + KTEST_MSG_ATTR_FUNC = 2 + KTEST_MSG_ATTR_FILE = 3 + KTEST_MSG_ATTR_LINE = 4 + KTEST_MSG_ATTR_TEXT = 5 + KTEST_MSG_ATTR_LEVEL = 6 + KTEST_MSG_ATTR_META = 7 + + +class timespec(Structure): + _fields_ = [ + ("tv_sec", c_int64), + ("tv_nsec", c_long), + ] + + +class NlAttrTS(NlAttr): + DATA_LEN = sizeof(timespec) + + def __init__(self, nla_type, val): + self.ts = val + super().__init__(nla_type, b"") + + @property + def nla_len(self): + return NlAttr.HDR_LEN + self.DATA_LEN + + def _print_attr_value(self): + return " tv_sec={} tv_nsec={}".format(self.ts.tv_sec, self.ts.tv_nsec) + + @staticmethod + def _validate(data): + assert len(data) == NlAttr.HDR_LEN + NlAttrTS.DATA_LEN + nla_len, nla_type = struct.unpack("@HH", data[: NlAttr.HDR_LEN]) + assert nla_len == NlAttr.HDR_LEN + NlAttrTS.DATA_LEN + + @classmethod + def _parse(cls, data): + nla_len, nla_type = struct.unpack("@HH", data[: NlAttr.HDR_LEN]) + val = timespec.from_buffer_copy(data[NlAttr.HDR_LEN :]) + return cls(nla_type, val) + + def __bytes__(self): + return self._to_bytes(bytes(self.ts)) + + +ktest_info_attrs = prepare_attrs_map( + [ + AttrDescr(KtestAttrType.KTEST_ATTR_MOD_NAME, NlAttrStr), + AttrDescr(KtestAttrType.KTEST_ATTR_TEST_NAME, NlAttrStr), + AttrDescr(KtestAttrType.KTEST_ATTR_TEST_DESCR, NlAttrStr), + ] +) + + +ktest_msg_attrs = prepare_attrs_map( + [ + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_FUNC, NlAttrStr), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_FILE, NlAttrStr), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_LINE, NlAttrU32), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_TEXT, NlAttrStr), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_LEVEL, NlAttrU8), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_TS, NlAttrTS), + ] +) + + +class KtestInfoMessage(NetlinkGenlMessage): + messages = [ + NlMsgProps(KtestMsgType.KTEST_CMD_LIST, NlMsgCategory.GET), + NlMsgProps(KtestMsgType.KTEST_CMD_RUN, NlMsgCategory.NEW), + NlMsgProps(KtestMsgType.KTEST_CMD_NEWTEST, NlMsgCategory.NEW), + ] + nl_attrs_map = ktest_info_attrs + family_name = KtestFamilyName + + +class KtestMsgMessage(NetlinkGenlMessage): + messages = [ + NlMsgProps(KtestMsgType.KTEST_CMD_NEWMESSAGE, NlMsgCategory.NEW), + ] + nl_attrs_map = ktest_msg_attrs + family_name = KtestFamilyName + + +handler_classes = { + CarpFamilyName: [CarpGenMessage], + GenlCtrlFamilyName: [NetlinkGenlCtrlMessage], + KtestFamilyName: [KtestInfoMessage, KtestMsgMessage], +} diff --git a/tests/atf_python/sys/netlink/netlink_route.py b/tests/atf_python/sys/netlink/netlink_route.py new file mode 100644 index 000000000000..2cfeb57da13f --- /dev/null +++ b/tests/atf_python/sys/netlink/netlink_route.py @@ -0,0 +1,832 @@ +import socket +import struct +from ctypes import c_int +from ctypes import c_ubyte +from ctypes import c_uint +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from enum import auto +from enum import Enum + +from atf_python.sys.netlink.attrs import NlAttr +from atf_python.sys.netlink.attrs import NlAttrIp +from atf_python.sys.netlink.attrs import NlAttrNested +from atf_python.sys.netlink.attrs import NlAttrStr +from atf_python.sys.netlink.attrs import NlAttrU32 +from atf_python.sys.netlink.attrs import NlAttrU8 +from atf_python.sys.netlink.message import StdNetlinkMessage +from atf_python.sys.netlink.message import NlMsgProps +from atf_python.sys.netlink.message import NlMsgCategory +from atf_python.sys.netlink.utils import AttrDescr +from atf_python.sys.netlink.utils import get_bitmask_str +from atf_python.sys.netlink.utils import prepare_attrs_map + + +class RtattrType(Enum): + RTA_UNSPEC = 0 + RTA_DST = 1 + RTA_SRC = 2 + RTA_IIF = 3 + RTA_OIF = 4 + RTA_GATEWAY = 5 + RTA_PRIORITY = 6 + RTA_PREFSRC = 7 + RTA_METRICS = 8 + RTA_MULTIPATH = 9 + # RTA_PROTOINFO = 10 + RTA_KNH_ID = 10 + RTA_FLOW = 11 + RTA_CACHEINFO = 12 + RTA_SESSION = 13 + # RTA_MP_ALGO = 14 + RTA_RTFLAGS = 14 + RTA_TABLE = 15 + RTA_MARK = 16 + RTA_MFC_STATS = 17 + RTA_VIA = 18 + RTA_NEWDST = 19 + RTA_PREF = 20 + RTA_ENCAP_TYPE = 21 + RTA_ENCAP = 22 + RTA_EXPIRES = 23 + RTA_PAD = 24 + RTA_UID = 25 + RTA_TTL_PROPAGATE = 26 + RTA_IP_PROTO = 27 + RTA_SPORT = 28 + RTA_DPORT = 29 + RTA_NH_ID = 30 + + +class NlRtMsgType(Enum): + RTM_NEWLINK = 16 + RTM_DELLINK = 17 + RTM_GETLINK = 18 + RTM_SETLINK = 19 + RTM_NEWADDR = 20 + RTM_DELADDR = 21 + RTM_GETADDR = 22 + RTM_NEWROUTE = 24 + RTM_DELROUTE = 25 + RTM_GETROUTE = 26 + RTM_NEWNEIGH = 28 + RTM_DELNEIGH = 29 + RTM_GETNEIGH = 30 + RTM_NEWRULE = 32 + RTM_DELRULE = 33 + RTM_GETRULE = 34 + RTM_NEWQDISC = 36 + RTM_DELQDISC = 37 + RTM_GETQDISC = 38 + RTM_NEWTCLASS = 40 + RTM_DELTCLASS = 41 + RTM_GETTCLASS = 42 + RTM_NEWTFILTER = 44 + RTM_DELTFILTER = 45 + RTM_GETTFILTER = 46 + RTM_NEWACTION = 48 + RTM_DELACTION = 49 + RTM_GETACTION = 50 + RTM_NEWPREFIX = 52 + RTM_GETMULTICAST = 58 + RTM_GETANYCAST = 62 + RTM_NEWNEIGHTBL = 64 + RTM_GETNEIGHTBL = 66 + RTM_SETNEIGHTBL = 67 + RTM_NEWNDUSEROPT = 68 + RTM_NEWADDRLABEL = 72 + RTM_DELADDRLABEL = 73 + RTM_GETADDRLABEL = 74 + RTM_GETDCB = 78 + RTM_SETDCB = 79 + RTM_NEWNETCONF = 80 + RTM_GETNETCONF = 82 + RTM_NEWMDB = 84 + RTM_DELMDB = 85 + RTM_GETMDB = 86 + RTM_NEWNSID = 88 + RTM_DELNSID = 89 + RTM_GETNSID = 90 + RTM_NEWSTATS = 92 + RTM_GETSTATS = 94 + + +class RtAttr(Structure): + _fields_ = [ + ("rta_len", c_ushort), + ("rta_type", c_ushort), + ] + + +class RtMsgHdr(Structure): + _fields_ = [ + ("rtm_family", c_ubyte), + ("rtm_dst_len", c_ubyte), + ("rtm_src_len", c_ubyte), + ("rtm_tos", c_ubyte), + ("rtm_table", c_ubyte), + ("rtm_protocol", c_ubyte), + ("rtm_scope", c_ubyte), + ("rtm_type", c_ubyte), + ("rtm_flags", c_uint), + ] + + +class RtMsgFlags(Enum): + RTM_F_NOTIFY = 0x100 + RTM_F_CLONED = 0x200 + RTM_F_EQUALIZE = 0x400 + RTM_F_PREFIX = 0x800 + RTM_F_LOOKUP_TABLE = 0x1000 + RTM_F_FIB_MATCH = 0x2000 + RTM_F_OFFLOAD = 0x4000 + RTM_F_TRAP = 0x8000 + RTM_F_OFFLOAD_FAILED = 0x20000000 + + +class RtScope(Enum): + RT_SCOPE_UNIVERSE = 0 + RT_SCOPE_SITE = 200 + RT_SCOPE_LINK = 253 + RT_SCOPE_HOST = 254 + RT_SCOPE_NOWHERE = 255 + + +class RtType(Enum): + RTN_UNSPEC = 0 + RTN_UNICAST = auto() + RTN_LOCAL = auto() + RTN_BROADCAST = auto() + RTN_ANYCAST = auto() + RTN_MULTICAST = auto() + RTN_BLACKHOLE = auto() + RTN_UNREACHABLE = auto() + RTN_PROHIBIT = auto() + RTN_THROW = auto() + RTN_NAT = auto() + RTN_XRESOLVE = auto() + + +class RtProto(Enum): + RTPROT_UNSPEC = 0 + RTPROT_REDIRECT = 1 + RTPROT_KERNEL = 2 + RTPROT_BOOT = 3 + RTPROT_STATIC = 4 + RTPROT_GATED = 8 + RTPROT_RA = 9 + RTPROT_MRT = 10 + RTPROT_ZEBRA = 11 + RTPROT_BIRD = 12 + RTPROT_DNROUTED = 13 + RTPROT_XORP = 14 + RTPROT_NTK = 15 + RTPROT_DHCP = 16 + RTPROT_MROUTED = 17 + RTPROT_KEEPALIVED = 18 + RTPROT_BABEL = 42 + RTPROT_OPENR = 99 + RTPROT_BGP = 186 + RTPROT_ISIS = 187 + RTPROT_OSPF = 188 + RTPROT_RIP = 189 + RTPROT_EIGRP = 192 + + +class NlRtaxType(Enum): + RTAX_UNSPEC = 0 + RTAX_LOCK = auto() + RTAX_MTU = auto() + RTAX_WINDOW = auto() + RTAX_RTT = auto() + RTAX_RTTVAR = auto() + RTAX_SSTHRESH = auto() + RTAX_CWND = auto() + RTAX_ADVMSS = auto() + RTAX_REORDERING = auto() + RTAX_HOPLIMIT = auto() + RTAX_INITCWND = auto() + RTAX_FEATURES = auto() + RTAX_RTO_MIN = auto() + RTAX_INITRWND = auto() + RTAX_QUICKACK = auto() + RTAX_CC_ALGO = auto() + RTAX_FASTOPEN_NO_COOKIE = auto() + + +class RtFlagsBSD(Enum): + RTF_UP = 0x1 + RTF_GATEWAY = 0x2 + RTF_HOST = 0x4 + RTF_REJECT = 0x8 + RTF_DYNAMIC = 0x10 + RTF_MODIFIED = 0x20 + RTF_DONE = 0x40 + RTF_XRESOLVE = 0x200 + RTF_LLINFO = 0x400 + RTF_LLDATA = 0x400 + RTF_STATIC = 0x800 + RTF_BLACKHOLE = 0x1000 + RTF_PROTO2 = 0x4000 + RTF_PROTO1 = 0x8000 + RTF_PROTO3 = 0x40000 + RTF_FIXEDMTU = 0x80000 + RTF_PINNED = 0x100000 + RTF_LOCAL = 0x200000 + RTF_BROADCAST = 0x400000 + RTF_MULTICAST = 0x800000 + RTF_STICKY = 0x10000000 + RTF_RNH_LOCKED = 0x40000000 + RTF_GWFLAG_COMPAT = 0x80000000 + + +class NlRtGroup(Enum): + RTNLGRP_NONE = 0 + RTNLGRP_LINK = auto() + RTNLGRP_NOTIFY = auto() + RTNLGRP_NEIGH = auto() + RTNLGRP_TC = auto() + RTNLGRP_IPV4_IFADDR = auto() + RTNLGRP_IPV4_MROUTE = auto() + RTNLGRP_IPV4_ROUTE = auto() + RTNLGRP_IPV4_RULE = auto() + RTNLGRP_IPV6_IFADDR = auto() + RTNLGRP_IPV6_MROUTE = auto() + RTNLGRP_IPV6_ROUTE = auto() + RTNLGRP_IPV6_IFINFO = auto() + RTNLGRP_DECnet_IFADDR = auto() + RTNLGRP_NOP2 = auto() + RTNLGRP_DECnet_ROUTE = auto() + RTNLGRP_DECnet_RULE = auto() + RTNLGRP_NOP4 = auto() + RTNLGRP_IPV6_PREFIX = auto() + RTNLGRP_IPV6_RULE = auto() + RTNLGRP_ND_USEROPT = auto() + RTNLGRP_PHONET_IFADDR = auto() + RTNLGRP_PHONET_ROUTE = auto() + RTNLGRP_DCB = auto() + RTNLGRP_IPV4_NETCONF = auto() + RTNLGRP_IPV6_NETCONF = auto() + RTNLGRP_MDB = auto() + RTNLGRP_MPLS_ROUTE = auto() + RTNLGRP_NSID = auto() + RTNLGRP_MPLS_NETCONF = auto() + RTNLGRP_IPV4_MROUTE_R = auto() + RTNLGRP_IPV6_MROUTE_R = auto() + RTNLGRP_NEXTHOP = auto() + RTNLGRP_BRVLAN = auto() + + +class IfinfoMsg(Structure): + _fields_ = [ + ("ifi_family", c_ubyte), + ("__ifi_pad", c_ubyte), + ("ifi_type", c_ushort), + ("ifi_index", c_int), + ("ifi_flags", c_uint), + ("ifi_change", c_uint), + ] + + +class IflattrType(Enum): + IFLA_UNSPEC = 0 + IFLA_ADDRESS = 1 + IFLA_BROADCAST = 2 + IFLA_IFNAME = 3 + IFLA_MTU = 4 + IFLA_LINK = 5 + IFLA_QDISC = 6 + IFLA_STATS = 7 + IFLA_COST = 8 + IFLA_PRIORITY = 9 + IFLA_MASTER = 10 + IFLA_WIRELESS = 11 + IFLA_PROTINFO = 12 + IFLA_TXQLEN = 13 + IFLA_MAP = 14 + IFLA_WEIGHT = 15 + IFLA_OPERSTATE = 16 + IFLA_LINKMODE = 17 + IFLA_LINKINFO = 18 + IFLA_NET_NS_PID = 19 + IFLA_IFALIAS = 20 + IFLA_NUM_VF = 21 + IFLA_VFINFO_LIST = 22 + IFLA_STATS64 = 23 + IFLA_VF_PORTS = 24 + IFLA_PORT_SELF = 25 + IFLA_AF_SPEC = 26 + IFLA_GROUP = 27 + IFLA_NET_NS_FD = 28 + IFLA_EXT_MASK = 29 + IFLA_PROMISCUITY = 30 + IFLA_NUM_TX_QUEUES = 31 + IFLA_NUM_RX_QUEUES = 32 + IFLA_CARRIER = 33 + IFLA_PHYS_PORT_ID = 34 + IFLA_CARRIER_CHANGES = 35 + IFLA_PHYS_SWITCH_ID = 36 + IFLA_LINK_NETNSID = 37 + IFLA_PHYS_PORT_NAME = 38 + IFLA_PROTO_DOWN = 39 + IFLA_GSO_MAX_SEGS = 40 + IFLA_GSO_MAX_SIZE = 41 + IFLA_PAD = 42 + IFLA_XDP = 43 + IFLA_EVENT = 44 + IFLA_NEW_NETNSID = 45 + IFLA_IF_NETNSID = 46 + IFLA_CARRIER_UP_COUNT = 47 + IFLA_CARRIER_DOWN_COUNT = 48 + IFLA_NEW_IFINDEX = 49 + IFLA_MIN_MTU = 50 + IFLA_MAX_MTU = 51 + IFLA_PROP_LIST = 52 + IFLA_ALT_IFNAME = 53 + IFLA_PERM_ADDRESS = 54 + IFLA_PROTO_DOWN_REASON = 55 + IFLA_PARENT_DEV_NAME = 56 + IFLA_PARENT_DEV_BUS_NAME = 57 + IFLA_GRO_MAX_SIZE = 58 + IFLA_TSO_MAX_SEGS = 59 + IFLA_ALLMULTI = 60 + IFLA_DEVLINK_PORT = 61 + IFLA_GSO_IPV4_MAX_SIZE = 62 + IFLA_GRO_IPV4_MAX_SIZE = 63 + IFLA_FREEBSD = 64 + + +class IflafAttrType(Enum): + IFLAF_UNSPEC = 0 + IFLAF_ORIG_IFNAME = 1 + IFLAF_ORIG_HWADDR = 2 + + +class IflinkInfo(Enum): + IFLA_INFO_UNSPEC = 0 + IFLA_INFO_KIND = auto() + IFLA_INFO_DATA = auto() + IFLA_INFO_XSTATS = auto() + IFLA_INFO_SLAVE_KIND = auto() + IFLA_INFO_SLAVE_DATA = auto() + + +class IfLinkInfoDataVlan(Enum): + IFLA_VLAN_UNSPEC = 0 + IFLA_VLAN_ID = auto() + IFLA_VLAN_FLAGS = auto() + IFLA_VLAN_EGRESS_QOS = auto() + IFLA_VLAN_INGRESS_QOS = auto() + IFLA_VLAN_PROTOCOL = auto() + + +class IfaddrMsg(Structure): + _fields_ = [ + ("ifa_family", c_ubyte), + ("ifa_prefixlen", c_ubyte), + ("ifa_flags", c_ubyte), + ("ifa_scope", c_ubyte), + ("ifa_index", c_uint), + ] + + +class IfaAttrType(Enum): + IFA_UNSPEC = 0 + IFA_ADDRESS = 1 + IFA_LOCAL = 2 + IFA_LABEL = 3 + IFA_BROADCAST = 4 + IFA_ANYCAST = 5 + IFA_CACHEINFO = 6 + IFA_MULTICAST = 7 + IFA_FLAGS = 8 + IFA_RT_PRIORITY = 9 + IFA_TARGET_NETNSID = 10 + IFA_FREEBSD = 11 + + +class IfafAttrType(Enum): + IFAF_UNSPEC = 0 + IFAF_VHID = 1 + IFAF_FLAGS = 2 + + +class IfaCacheInfo(Structure): + _fields_ = [ + ("ifa_prefered", c_uint), # seconds till the end of the prefix considered preferred + ("ifa_valid", c_uint), # seconds till the end of the prefix considered valid + ("cstamp", c_uint), # creation time in 1ms intervals from the boot time + ("tstamp", c_uint), # update time in 1ms intervals from the boot time + ] + + +class IfaFlags(Enum): + IFA_F_TEMPORARY = 0x01 + IFA_F_NODAD = 0x02 + IFA_F_OPTIMISTIC = 0x04 + IFA_F_DADFAILED = 0x08 + IFA_F_HOMEADDRESS = 0x10 + IFA_F_DEPRECATED = 0x20 + IFA_F_TENTATIVE = 0x40 + IFA_F_PERMANENT = 0x80 + IFA_F_MANAGETEMPADDR = 0x100 + IFA_F_NOPREFIXROUTE = 0x200 + IFA_F_MCAUTOJOIN = 0x400 + IFA_F_STABLE_PRIVACY = 0x800 + + +class IfafFlags6(Enum): + IN6_IFF_ANYCAST = 0x01 + IN6_IFF_TENTATIVE = 0x02 + IN6_IFF_DUPLICATED = 0x04 + IN6_IFF_DETACHED = 0x08 + IN6_IFF_DEPRECATED = 0x10 + IN6_IFF_NODAD = 0x20 + IN6_IFF_AUTOCONF = 0x40 + IN6_IFF_TEMPORARY = 0x80 + IN6_IFF_PREFER_SOURCE = 0x100 + + +class NdMsg(Structure): + _fields_ = [ + ("ndm_family", c_ubyte), + ("ndm_pad1", c_ubyte), + ("ndm_pad2", c_ubyte), + ("ndm_ifindex", c_uint), + ("ndm_state", c_ushort), + ("ndm_flags", c_ubyte), + ("ndm_type", c_ubyte), + ] + + +class NdAttrType(Enum): + NDA_UNSPEC = 0 + NDA_DST = 1 + NDA_LLADDR = 2 + NDA_CACHEINFO = 3 + NDA_PROBES = 4 + NDA_VLAN = 5 + NDA_PORT = 6 + NDA_VNI = 7 + NDA_IFINDEX = 8 + NDA_MASTER = 9 + NDA_LINK_NETNSID = 10 + NDA_SRC_VNI = 11 + NDA_PROTOCOL = 12 + NDA_NH_ID = 13 + NDA_FDB_EXT_ATTRS = 14 + NDA_FLAGS_EXT = 15 + NDA_NDM_STATE_MASK = 16 + NDA_NDM_FLAGS_MASK = 17 + + +class NlAttrRtFlags(NlAttrU32): + def _print_attr_value(self): + s = get_bitmask_str(RtFlagsBSD, self.u32) + return " rtflags={}".format(s) + + +class NlAttrIfindex(NlAttrU32): + def _print_attr_value(self): + try: + ifname = socket.if_indextoname(self.u32) + return " iface={}(#{})".format(ifname, self.u32) + except OSError: + pass + return " iface=if#{}".format(self.u32) + + +class NlAttrTable(NlAttrU32): + def _print_attr_value(self): + return " rtable={}".format(self.u32) + + +class NlAttrNhId(NlAttrU32): + def _print_attr_value(self): + return " nh_id={}".format(self.u32) + + +class NlAttrKNhId(NlAttrU32): + def _print_attr_value(self): + return " knh_id={}".format(self.u32) + + +class NlAttrMac(NlAttr): + def _print_attr_value(self): + return ' mac="' + ":".join(["{:02X}".format(b) for b in self._data]) + '"' + + +class NlAttrIfStats(NlAttr): + def _print_attr_value(self): + return " stats={...}" + + +class NlAttrCacheInfo(NlAttr): + def __init__(self, nla_type, data): + super().__init__(nla_type, data) + self.ci = IfaCacheInfo.from_buffer_copy(data) + + @staticmethod + def _validate(data): + nla_len, nla_type = struct.unpack("@HH", data[:4]) + data_len = nla_len - 4 + if data_len != sizeof(IfaCacheInfo): + raise ValueError( + "Error validating attr {}: wrong size".format(nla_type) + ) # noqa: E501 + + def _print_attr_value(self): + return " ifa_prefered={} ifa_valid={} cstamp={} tstamp={}".format( + self.ci.ifa_prefered, self.ci.ifa_valid, self.ci.cstamp, self.ci.tstamp) + + +class NlAttrVia(NlAttr): + def __init__(self, nla_type, family, addr: str): + super().__init__(nla_type, b"") + self.addr = addr + self.family = family + + @staticmethod + def _validate(data): + nla_len, nla_type = struct.unpack("@HH", data[:4]) + data_len = nla_len - 4 + if data_len == 0: + raise ValueError( + "Error validating attr {}: empty data".format(nla_type) + ) # noqa: E501 + family = int(data_len[0]) + if family not in (socket.AF_INET, socket.AF_INET6): + raise ValueError( + "Error validating attr {}: unsupported AF {}".format( # noqa: E501 + nla_type, family + ) + ) + if family == socket.AF_INET: + expected_len = 1 + 4 + else: + expected_len = 1 + 16 + if data_len != expected_len: + raise ValueError( + "Error validating attr {}: expected len {} got {}".format( # noqa: E501 + nla_type, expected_len, data_len + ) + ) + + @property + def nla_len(self): + if self.family == socket.AF_INET6: + return 21 + else: + return 9 + + @classmethod + def _parse(cls, data): + nla_len, nla_type, family = struct.unpack("@HHB", data[:5]) + off = 5 + if family == socket.AF_INET: + addr = socket.inet_ntop(family, data[off:off + 4]) + else: + addr = socket.inet_ntop(family, data[off:off + 16]) + return cls(nla_type, family, addr) + + def __bytes__(self): + addr = socket.inet_pton(self.family, self.addr) + return self._to_bytes(struct.pack("@B", self.family) + addr) + + def _print_attr_value(self): + return " via={}".format(self.addr) + + +rtnl_route_attrs = prepare_attrs_map( + [ + AttrDescr(RtattrType.RTA_DST, NlAttrIp), + AttrDescr(RtattrType.RTA_SRC, NlAttrIp), + AttrDescr(RtattrType.RTA_IIF, NlAttrIfindex), + AttrDescr(RtattrType.RTA_OIF, NlAttrIfindex), + AttrDescr(RtattrType.RTA_GATEWAY, NlAttrIp), + AttrDescr(RtattrType.RTA_TABLE, NlAttrTable), + AttrDescr(RtattrType.RTA_PRIORITY, NlAttrU32), + AttrDescr(RtattrType.RTA_VIA, NlAttrVia), + AttrDescr(RtattrType.RTA_NH_ID, NlAttrNhId), + AttrDescr(RtattrType.RTA_KNH_ID, NlAttrKNhId), + AttrDescr(RtattrType.RTA_RTFLAGS, NlAttrRtFlags), + AttrDescr( + RtattrType.RTA_METRICS, + NlAttrNested, + [ + AttrDescr(NlRtaxType.RTAX_MTU, NlAttrU32), + ], + ), + ] +) + +rtnl_ifla_attrs = prepare_attrs_map( + [ + AttrDescr(IflattrType.IFLA_ADDRESS, NlAttrMac), + AttrDescr(IflattrType.IFLA_BROADCAST, NlAttrMac), + AttrDescr(IflattrType.IFLA_IFNAME, NlAttrStr), + AttrDescr(IflattrType.IFLA_MTU, NlAttrU32), + AttrDescr(IflattrType.IFLA_LINK, NlAttrU32), + AttrDescr(IflattrType.IFLA_PROMISCUITY, NlAttrU32), + AttrDescr(IflattrType.IFLA_OPERSTATE, NlAttrU8), + AttrDescr(IflattrType.IFLA_CARRIER, NlAttrU8), + AttrDescr(IflattrType.IFLA_IFALIAS, NlAttrStr), + AttrDescr(IflattrType.IFLA_STATS64, NlAttrIfStats), + AttrDescr(IflattrType.IFLA_NEW_IFINDEX, NlAttrU32), + AttrDescr( + IflattrType.IFLA_LINKINFO, + NlAttrNested, + [ + AttrDescr(IflinkInfo.IFLA_INFO_KIND, NlAttrStr), + AttrDescr(IflinkInfo.IFLA_INFO_DATA, NlAttr), + ], + ), + AttrDescr( + IflattrType.IFLA_FREEBSD, + NlAttrNested, + [ + AttrDescr(IflafAttrType.IFLAF_ORIG_HWADDR, NlAttrMac), + ], + ), + ] +) + +rtnl_ifa_attrs = prepare_attrs_map( + [ + AttrDescr(IfaAttrType.IFA_ADDRESS, NlAttrIp), + AttrDescr(IfaAttrType.IFA_LOCAL, NlAttrIp), + AttrDescr(IfaAttrType.IFA_LABEL, NlAttrStr), + AttrDescr(IfaAttrType.IFA_BROADCAST, NlAttrIp), + AttrDescr(IfaAttrType.IFA_ANYCAST, NlAttrIp), + AttrDescr(IfaAttrType.IFA_FLAGS, NlAttrU32), + AttrDescr(IfaAttrType.IFA_CACHEINFO, NlAttrCacheInfo), + AttrDescr( + IfaAttrType.IFA_FREEBSD, + NlAttrNested, + [ + AttrDescr(IfafAttrType.IFAF_VHID, NlAttrU32), + AttrDescr(IfafAttrType.IFAF_FLAGS, NlAttrU32), + ], + ), + ] +) + + +rtnl_nd_attrs = prepare_attrs_map( + [ + AttrDescr(NdAttrType.NDA_DST, NlAttrIp), + AttrDescr(NdAttrType.NDA_IFINDEX, NlAttrIfindex), + AttrDescr(NdAttrType.NDA_FLAGS_EXT, NlAttrU32), + AttrDescr(NdAttrType.NDA_LLADDR, NlAttrMac), + ] +) + + +class BaseNetlinkRtMessage(StdNetlinkMessage): + pass + + +class NetlinkRtMessage(BaseNetlinkRtMessage): + messages = [ + NlMsgProps(NlRtMsgType.RTM_NEWROUTE, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELROUTE, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETROUTE, NlMsgCategory.GET), + ] + nl_attrs_map = rtnl_route_attrs + + def __init__(self, helper, nlm_type): + super().__init__(helper, nlm_type) + self.base_hdr = RtMsgHdr() + + def parse_base_header(self, data): + if len(data) < sizeof(RtMsgHdr): + raise ValueError("length less than rtmsg header") + rtm_hdr = RtMsgHdr.from_buffer_copy(data) + return (rtm_hdr, sizeof(RtMsgHdr)) + + def print_base_header(self, hdr, prepend=""): + family = self.helper.get_af_name(hdr.rtm_family) + print( + "{}family={}, dst_len={}, src_len={}, tos={}, table={}, protocol={}({}), scope={}({}), type={}({}), flags={}({})".format( # noqa: E501 + prepend, + family, + hdr.rtm_dst_len, + hdr.rtm_src_len, + hdr.rtm_tos, + hdr.rtm_table, + self.helper.get_attr_byval(RtProto, hdr.rtm_protocol), + hdr.rtm_protocol, + self.helper.get_attr_byval(RtScope, hdr.rtm_scope), + hdr.rtm_scope, + self.helper.get_attr_byval(RtType, hdr.rtm_type), + hdr.rtm_type, + self.helper.get_bitmask_str(RtMsgFlags, hdr.rtm_flags), + hdr.rtm_flags, + ) + ) + + +class NetlinkIflaMessage(BaseNetlinkRtMessage): + messages = [ + NlMsgProps(NlRtMsgType.RTM_NEWLINK, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELLINK, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETLINK, NlMsgCategory.GET), + ] + nl_attrs_map = rtnl_ifla_attrs + + def __init__(self, helper, nlm_type): + super().__init__(helper, nlm_type) + self.base_hdr = IfinfoMsg() + + def parse_base_header(self, data): + if len(data) < sizeof(IfinfoMsg): + raise ValueError("length less than IfinfoMsg header") + rtm_hdr = IfinfoMsg.from_buffer_copy(data) + return (rtm_hdr, sizeof(IfinfoMsg)) + + def print_base_header(self, hdr, prepend=""): + family = self.helper.get_af_name(hdr.ifi_family) + print( + "{}family={}, ifi_type={}, ifi_index={}, ifi_flags={}, ifi_change={}".format( # noqa: E501 + prepend, + family, + hdr.ifi_type, + hdr.ifi_index, + hdr.ifi_flags, + hdr.ifi_change, + ) + ) + + +class NetlinkIfaMessage(BaseNetlinkRtMessage): + messages = [ + NlMsgProps(NlRtMsgType.RTM_NEWADDR, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELADDR, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETADDR, NlMsgCategory.GET), + ] + nl_attrs_map = rtnl_ifa_attrs + + def __init__(self, helper, nlm_type): + super().__init__(helper, nlm_type) + self.base_hdr = IfaddrMsg() + + def parse_base_header(self, data): + if len(data) < sizeof(IfaddrMsg): + raise ValueError("length less than IfaddrMsg header") + rtm_hdr = IfaddrMsg.from_buffer_copy(data) + return (rtm_hdr, sizeof(IfaddrMsg)) + + def print_base_header(self, hdr, prepend=""): + family = self.helper.get_af_name(hdr.ifa_family) + print( + "{}family={}, ifa_prefixlen={}, ifa_flags={}, ifa_scope={}, ifa_index={}".format( # noqa: E501 + prepend, + family, + hdr.ifa_prefixlen, + hdr.ifa_flags, + hdr.ifa_scope, + hdr.ifa_index, + ) + ) + + +class NetlinkNdMessage(BaseNetlinkRtMessage): + messages = [ + NlMsgProps(NlRtMsgType.RTM_NEWNEIGH, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELNEIGH, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETNEIGH, NlMsgCategory.GET), + ] + nl_attrs_map = rtnl_nd_attrs + + def __init__(self, helper, nlm_type): + super().__init__(helper, nlm_type) + self.base_hdr = NdMsg() + + def parse_base_header(self, data): + if len(data) < sizeof(NdMsg): + raise ValueError("length less than NdMsg header") + nd_hdr = NdMsg.from_buffer_copy(data) + return (nd_hdr, sizeof(NdMsg)) + + def print_base_header(self, hdr, prepend=""): + family = self.helper.get_af_name(hdr.ndm_family) + print( + "{}family={}, ndm_ifindex={}, ndm_state={}, ndm_flags={}".format( # noqa: E501 + prepend, + family, + hdr.ndm_ifindex, + hdr.ndm_state, + hdr.ndm_flags, + ) + ) + + +handler_classes = { + "netlink_route": [ + NetlinkRtMessage, + NetlinkIflaMessage, + NetlinkIfaMessage, + NetlinkNdMessage, + ], +} diff --git a/tests/atf_python/sys/netlink/utils.py b/tests/atf_python/sys/netlink/utils.py new file mode 100644 index 000000000000..f1d0ba3321ed --- /dev/null +++ b/tests/atf_python/sys/netlink/utils.py @@ -0,0 +1,80 @@ +#!/usr/local/bin/python3 +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple + + +class NlConst: + AF_NETLINK = 38 + NETLINK_ROUTE = 0 + NETLINK_GENERIC = 16 + GENL_ID_CTRL = 16 + + +def roundup2(val: int, num: int) -> int: + if val % num: + return (val | (num - 1)) + 1 + else: + return val + + +def align4(val: int) -> int: + return roundup2(val, 4) + + +def enum_or_int(val) -> int: + if isinstance(val, Enum): + return val.value + return val + + +class AttrDescr(NamedTuple): + val: Enum + cls: "NlAttr" + child_map: Any = None + is_array: bool = False + + +def prepare_attrs_map(attrs: List[AttrDescr]) -> Dict[str, Dict]: + ret = {} + for ad in attrs: + ret[ad.val.value] = {"ad": ad} + if ad.child_map: + ret[ad.val.value]["child"] = prepare_attrs_map(ad.child_map) + ret[ad.val.value]["is_array"] = ad.is_array + return ret + + +def build_propmap(cls): + ret = {} + for prop in dir(cls): + if not prop.startswith("_"): + ret[getattr(cls, prop).value] = prop + return ret + + +def get_bitmask_map(propmap, val): + v = 1 + ret = {} + while val: + if v & val: + if v in propmap: + ret[v] = propmap[v] + else: + ret[v] = hex(v) + val -= v + v *= 2 + return ret + + +def get_bitmask_str(cls, val): + if isinstance(cls, type): + pmap = build_propmap(cls) + else: + pmap = {} + for _cls in cls: + pmap.update(build_propmap(_cls)) + bmap = get_bitmask_map(pmap, val) + return ",".join([v for k, v in bmap.items()]) |