diff options
Diffstat (limited to 'tests/atf_python/sys/netlink/netlink.py')
-rw-r--r-- | tests/atf_python/sys/netlink/netlink.py | 417 |
1 files changed, 417 insertions, 0 deletions
diff --git a/tests/atf_python/sys/netlink/netlink.py b/tests/atf_python/sys/netlink/netlink.py new file mode 100644 index 000000000000..f8f886b09b24 --- /dev/null +++ b/tests/atf_python/sys/netlink/netlink.py @@ -0,0 +1,417 @@ +#!/usr/local/bin/python3 +import os +import socket +import sys +from ctypes import c_int +from ctypes import c_ubyte +from ctypes import c_uint +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from enum import auto +from enum import Enum + +from atf_python.sys.netlink.attrs import NlAttr +from atf_python.sys.netlink.attrs import NlAttrStr +from atf_python.sys.netlink.attrs import NlAttrU32 +from atf_python.sys.netlink.base_headers import GenlMsgHdr +from atf_python.sys.netlink.base_headers import NlmBaseFlags +from atf_python.sys.netlink.base_headers import Nlmsghdr +from atf_python.sys.netlink.base_headers import NlMsgType +from atf_python.sys.netlink.message import BaseNetlinkMessage +from atf_python.sys.netlink.message import NlMsgCategory +from atf_python.sys.netlink.message import NlMsgProps +from atf_python.sys.netlink.message import StdNetlinkMessage +from atf_python.sys.netlink.netlink_generic import GenlCtrlAttrType +from atf_python.sys.netlink.netlink_generic import GenlCtrlMsgType +from atf_python.sys.netlink.netlink_generic import handler_classes as genl_classes +from atf_python.sys.netlink.netlink_route import handler_classes as rt_classes +from atf_python.sys.netlink.utils import align4 +from atf_python.sys.netlink.utils import AttrDescr +from atf_python.sys.netlink.utils import build_propmap +from atf_python.sys.netlink.utils import enum_or_int +from atf_python.sys.netlink.utils import get_bitmask_map +from atf_python.sys.netlink.utils import NlConst +from atf_python.sys.netlink.utils import prepare_attrs_map + + +class SockaddrNl(Structure): + _fields_ = [ + ("nl_len", c_ubyte), + ("nl_family", c_ubyte), + ("nl_pad", c_ushort), + ("nl_pid", c_uint), + ("nl_groups", c_uint), + ] + + +class Nlmsgdone(Structure): + _fields_ = [ + ("error", c_int), + ] + + +class Nlmsgerr(Structure): + _fields_ = [ + ("error", c_int), + ("msg", Nlmsghdr), + ] + + +class NlErrattrType(Enum): + NLMSGERR_ATTR_UNUSED = 0 + NLMSGERR_ATTR_MSG = auto() + NLMSGERR_ATTR_OFFS = auto() + NLMSGERR_ATTR_COOKIE = auto() + NLMSGERR_ATTR_POLICY = auto() + + +class AddressFamilyLinux(Enum): + AF_INET = socket.AF_INET + AF_INET6 = socket.AF_INET6 + AF_NETLINK = 16 + + +class AddressFamilyBsd(Enum): + AF_INET = socket.AF_INET + AF_INET6 = socket.AF_INET6 + AF_NETLINK = 38 + + +class NlHelper: + def __init__(self): + self._pmap = {} + self._af_cls = self.get_af_cls() + self._seq_counter = 1 + self.pid = os.getpid() + + def get_seq(self): + ret = self._seq_counter + self._seq_counter += 1 + return ret + + def get_af_cls(self): + if sys.platform.startswith("freebsd"): + cls = AddressFamilyBsd + else: + cls = AddressFamilyLinux + return cls + + def get_propmap(self, cls): + if cls not in self._pmap: + self._pmap[cls] = build_propmap(cls) + return self._pmap[cls] + + def get_name_propmap(self, cls): + ret = {} + for prop in dir(cls): + if not prop.startswith("_"): + ret[prop] = getattr(cls, prop).value + return ret + + def get_attr_byval(self, cls, attr_val): + propmap = self.get_propmap(cls) + return propmap.get(attr_val) + + def get_af_name(self, family): + v = self.get_attr_byval(self._af_cls, family) + if v is not None: + return v + return "af#{}".format(family) + + def get_af_value(self, family_str: str) -> int: + propmap = self.get_name_propmap(self._af_cls) + return propmap.get(family_str) + + def get_bitmask_str(self, cls, val): + bmap = get_bitmask_map(self.get_propmap(cls), val) + return ",".join([v for k, v in bmap.items()]) + + @staticmethod + def get_bitmask_str_uncached(cls, val): + pmap = NlHelper.build_propmap(cls) + bmap = NlHelper.get_bitmask_map(pmap, val) + return ",".join([v for k, v in bmap.items()]) + + +nldone_attrs = prepare_attrs_map([]) + +nlerr_attrs = prepare_attrs_map( + [ + AttrDescr(NlErrattrType.NLMSGERR_ATTR_MSG, NlAttrStr), + AttrDescr(NlErrattrType.NLMSGERR_ATTR_OFFS, NlAttrU32), + AttrDescr(NlErrattrType.NLMSGERR_ATTR_COOKIE, NlAttr), + ] +) + + +class NetlinkDoneMessage(StdNetlinkMessage): + messages = [NlMsgProps(NlMsgType.NLMSG_DONE, NlMsgCategory.ACK)] + nl_attrs_map = nldone_attrs + + @property + def error_code(self): + return self.base_hdr.error + + def parse_base_header(self, data): + if len(data) < sizeof(Nlmsgdone): + raise ValueError("length less than nlmsgdone header") + done_hdr = Nlmsgdone.from_buffer_copy(data) + sz = sizeof(Nlmsgdone) + return (done_hdr, sz) + + def print_base_header(self, hdr, prepend=""): + print("{}error={}".format(prepend, hdr.error)) + + +class NetlinkErrorMessage(StdNetlinkMessage): + messages = [NlMsgProps(NlMsgType.NLMSG_ERROR, NlMsgCategory.ACK)] + nl_attrs_map = nlerr_attrs + + @property + def error_code(self): + return self.base_hdr.error + + @property + def error_str(self): + nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_MSG) + if nla: + return nla.text + return None + + @property + def error_offset(self): + nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_OFFS) + if nla: + return nla.u32 + return None + + @property + def cookie(self): + return self.get_nla(NlErrattrType.NLMSGERR_ATTR_COOKIE) + + def parse_base_header(self, data): + if len(data) < sizeof(Nlmsgerr): + raise ValueError("length less than nlmsgerr header") + err_hdr = Nlmsgerr.from_buffer_copy(data) + sz = sizeof(Nlmsgerr) + if (self.nl_hdr.nlmsg_flags & 0x100) == 0: + sz += align4(err_hdr.msg.nlmsg_len - sizeof(Nlmsghdr)) + return (err_hdr, sz) + + def print_base_header(self, errhdr, prepend=""): + print("{}error={}, ".format(prepend, errhdr.error), end="") + hdr = errhdr.msg + print( + "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( + prepend, + hdr.nlmsg_len, + "msg#{}".format(hdr.nlmsg_type), + self.helper.get_bitmask_str(NlmBaseFlags, hdr.nlmsg_flags), + hdr.nlmsg_flags, + hdr.nlmsg_seq, + hdr.nlmsg_pid, + ) + ) + + +core_classes = { + "netlink_core": [ + NetlinkDoneMessage, + NetlinkErrorMessage, + ], +} + + +class Nlsock: + HANDLER_CLASSES = [core_classes, rt_classes, genl_classes] + + def __init__(self, family, helper): + self.helper = helper + self.sock_fd = self._setup_netlink(family) + self._sock_family = family + self._data = bytes() + self.msgmap = self.build_msgmap() + self._family_map = { + NlConst.GENL_ID_CTRL: "nlctrl", + } + + def build_msgmap(self): + handler_classes = {} + for d in self.HANDLER_CLASSES: + handler_classes.update(d) + xmap = {} + # 'family_name': [class.messages[MsgProps.msg], ] + for family_id, family_classes in handler_classes.items(): + xmap[family_id] = {} + for cls in family_classes: + for msg_props in cls.messages: + xmap[family_id][enum_or_int(msg_props.msg)] = cls + return xmap + + def _setup_netlink(self, netlink_family) -> int: + family = self.helper.get_af_value("AF_NETLINK") + s = socket.socket(family, socket.SOCK_RAW, netlink_family) + s.setsockopt(270, 10, 1) # NETLINK_CAP_ACK + s.setsockopt(270, 11, 1) # NETLINK_EXT_ACK + return s + + def set_groups(self, mask: int): + self.sock_fd.setsockopt(socket.SOL_SOCKET, 1, mask) + # snl = SockaddrNl(nl_len = sizeof(SockaddrNl), nl_family=38, + # nl_pid=self.pid, nl_groups=mask) + # xbuffer = create_string_buffer(sizeof(SockaddrNl)) + # memmove(xbuffer, addressof(snl), sizeof(SockaddrNl)) + # k = struct.pack("@BBHII", 12, 38, 0, self.pid, mask) + # self.sock_fd.bind(k) + + def join_group(self, group_id: int): + self.sock_fd.setsockopt(270, 1, group_id) + + def write_message(self, msg, verbose=True): + if verbose: + print("vvvvvvvv OUT vvvvvvvv") + msg.print_message() + msg_bytes = bytes(msg) + try: + ret = os.write(self.sock_fd.fileno(), msg_bytes) + assert ret == len(msg_bytes) + except Exception as e: + print("write({}) -> {}".format(len(msg_bytes), e)) + + def parse_message(self, data: bytes): + if len(data) < sizeof(Nlmsghdr): + raise Exception("Short read from nl: {} bytes".format(len(data))) + hdr = Nlmsghdr.from_buffer_copy(data) + if hdr.nlmsg_type < 16: + family_name = "netlink_core" + nlmsg_type = hdr.nlmsg_type + elif self._sock_family == NlConst.NETLINK_ROUTE: + family_name = "netlink_route" + nlmsg_type = hdr.nlmsg_type + else: + # Genetlink + if len(data) < sizeof(Nlmsghdr) + sizeof(GenlMsgHdr): + raise Exception("Short read from genl: {} bytes".format(len(data))) + family_name = self._family_map.get(hdr.nlmsg_type, "") + ghdr = GenlMsgHdr.from_buffer_copy(data[sizeof(Nlmsghdr):]) + nlmsg_type = ghdr.cmd + cls = self.msgmap.get(family_name, {}).get(nlmsg_type) + if not cls: + cls = BaseNetlinkMessage + return cls.from_bytes(self.helper, data) + + def get_genl_family_id(self, family_name): + hdr = Nlmsghdr( + nlmsg_type=NlConst.GENL_ID_CTRL, + nlmsg_flags=NlmBaseFlags.NLM_F_REQUEST.value, + nlmsg_seq=self.helper.get_seq(), + ) + ghdr = GenlMsgHdr(cmd=GenlCtrlMsgType.CTRL_CMD_GETFAMILY.value) + nla = NlAttrStr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, family_name) + hdr.nlmsg_len = sizeof(Nlmsghdr) + sizeof(GenlMsgHdr) + len(bytes(nla)) + + msg_bytes = bytes(hdr) + bytes(ghdr) + bytes(nla) + self.write_data(msg_bytes) + while True: + rx_msg = self.read_message() + if hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: + if rx_msg.is_type(NlMsgType.NLMSG_ERROR): + if rx_msg.error_code != 0: + raise ValueError("unable to get family {}".format(family_name)) + else: + family_id = rx_msg.get_nla(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID).u16 + self._family_map[family_id] = family_name + return family_id + raise ValueError("unable to get family {}".format(family_name)) + + def write_data(self, data: bytes): + self.sock_fd.send(data) + + def read_data(self): + while True: + data = self.sock_fd.recv(65535) + self._data += data + if len(self._data) >= sizeof(Nlmsghdr): + break + + def read_message(self) -> bytes: + if len(self._data) < sizeof(Nlmsghdr): + self.read_data() + hdr = Nlmsghdr.from_buffer_copy(self._data) + while hdr.nlmsg_len > len(self._data): + self.read_data() + raw_msg = self._data[: hdr.nlmsg_len] + self._data = self._data[hdr.nlmsg_len:] + return self.parse_message(raw_msg) + + def get_reply(self, tx_msg): + self.write_message(tx_msg) + while True: + rx_msg = self.read_message() + if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: + return rx_msg + + +class NetlinkMultipartIterator(object): + def __init__(self, obj, seq_number: int, msg_type): + self._obj = obj + self._seq = seq_number + self._msg_type = msg_type + + def __iter__(self): + return self + + def __next__(self): + msg = self._obj.read_message() + if self._seq != msg.nl_hdr.nlmsg_seq: + raise ValueError("bad sequence number") + if msg.is_type(NlMsgType.NLMSG_ERROR): + raise ValueError( + "error while handling multipart msg: {}".format(msg.error_code) + ) + elif msg.is_type(NlMsgType.NLMSG_DONE): + if msg.error_code == 0: + raise StopIteration + raise ValueError( + "error listing some parts of the multipart msg: {}".format( + msg.error_code + ) + ) + elif not msg.is_type(self._msg_type): + raise ValueError("bad message type: {}".format(msg)) + return msg + + +class NetlinkTestTemplate(object): + REQUIRED_MODULES = ["netlink"] + + def setup_netlink(self, netlink_family: NlConst): + self.helper = NlHelper() + self.nlsock = Nlsock(netlink_family, self.helper) + + def write_message(self, msg, silent=False): + if not silent: + print("") + print("============= >> TX MESSAGE =============") + msg.print_message() + msg.print_as_bytes(bytes(msg), "-- DATA --") + self.nlsock.write_data(bytes(msg)) + + def read_message(self, silent=False): + msg = self.nlsock.read_message() + if not silent: + print("") + print("============= << RX MESSAGE =============") + msg.print_message() + return msg + + def get_reply(self, tx_msg): + self.write_message(tx_msg) + while True: + rx_msg = self.read_message() + if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: + return rx_msg + + def read_msg_list(self, seq, msg_type): + return list(NetlinkMultipartIterator(self, seq, msg_type)) |