about summary refs log tree commit diff
path: root/tests/atf_python/sys/net
diff options
context:
space:
mode:
author Alexander V. Chernikov <melifaro@FreeBSD.org> 2023-01-07 16:18:39 +0000
committer Alexander V. Chernikov <melifaro@FreeBSD.org> 2023-01-08 15:06:34 +0000
commit c1871a3372e382bfcd46452d1d8d4f06561508cc (patch)
tree 1e877342205e53194703db646181ad3237e7b6f9 /tests/atf_python/sys/net
parent 4ffe60e6833e22f304e63212cc3b3984e3cf643f (diff)
Diffstat (limited to 'tests/atf_python/sys/net')
-rw-r--r-- tests/atf_python/sys/net/netlink.py 102
1 files changed, 87 insertions, 15 deletions
diff --git a/tests/atf_python/sys/net/netlink.py b/tests/atf_python/sys/net/netlink.py
index 046519ce0343..57c8582627cf 100644
--- a/tests/atf_python/sys/net/netlink.py
+++ b/tests/atf_python/sys/net/netlink.py
@@ -49,6 +49,12 @@ class Nlmsghdr(Structure):
]
class Nlmsgdone(Structure):
    """ctypes layout of the NLMSG_DONE payload: a single native ``int`` error."""

    _fields_ = [("error", c_int)]
+
+
class Nlmsgerr(Structure):
_fields_ = [
("error", c_int),
@@ -961,6 +967,8 @@ rtnl_route_attrs = [
),
]
# No attribute descriptors are registered for NLMSG_DONE messages.
nldone_attrs = list()
+
nlerr_attrs = [
AttrDescr(NlErrattrType.NLMSGERR_ATTR_MSG, NlAttrStr),
AttrDescr(NlErrattrType.NLMSGERR_ATTR_OFFS, NlAttrU32),
@@ -989,6 +997,7 @@ rtnl_ifla_attrs = [
rtnl_ifa_attrs = [
AttrDescr(IfattrType.IFA_ADDRESS, NlAttrIp),
AttrDescr(IfattrType.IFA_LOCAL, NlAttrIp),
+ AttrDescr(IfattrType.IFA_LABEL, NlAttrStr),
AttrDescr(IfattrType.IFA_BROADCAST, NlAttrIp),
AttrDescr(IfattrType.IFA_ANYCAST, NlAttrIp),
AttrDescr(IfattrType.IFA_FLAGS, NlAttrU32),
@@ -1167,6 +1176,25 @@ class StdNetlinkMessage(BaseNetlinkMessage):
nla.print_attr(" ")
class NetlinkDoneMessage(StdNetlinkMessage):
    """Netlink NLMSG_DONE message.

    The base header following the netlink header is an ``Nlmsgdone``
    structure carrying a single int error code.
    """

    messages = [NlMsgType.NLMSG_DONE.value]
    nl_attrs_map = prepare_attrs_map(nldone_attrs)

    @property
    def error_code(self):
        """Error field of the parsed base header (``self.base_hdr``)."""
        return self.base_hdr.error

    def parse_base_header(self, data):
        """Decode the leading ``Nlmsgdone`` structure from *data*.

        Returns a ``(header, consumed_bytes)`` tuple.  Raises ValueError if
        *data* is shorter than the structure.
        """
        hdr_len = sizeof(Nlmsgdone)
        if hdr_len > len(data):
            raise ValueError("length less than nlmsgdone header")
        return Nlmsgdone.from_buffer_copy(data), hdr_len

    def print_base_header(self, hdr, prepend=""):
        """Print the header's error field, prefixed with *prepend*."""
        print(f"{prepend}error={hdr.error}")
+
+
class NetlinkErrorMessage(StdNetlinkMessage):
messages = [NlMsgType.NLMSG_ERROR.value]
nl_attrs_map = prepare_attrs_map(nlerr_attrs)
@@ -1340,6 +1368,7 @@ class Nlsock:
NetlinkRtMessage,
NetlinkIflaMessage,
NetlinkIfaMessage,
+ NetlinkDoneMessage,
NetlinkErrorMessage,
]
xmap = {}
@@ -1476,20 +1505,63 @@ class Nlsock:
self.write_data(msg_bytes)
class NetlinkMultipartIterator(object):
    """Iterate over the parts of a netlink multipart reply.

    Each step reads one message via ``obj.read_message()`` and validates it:

    * a sequence number different from *seq_number* raises ValueError;
    * an NLMSG_ERROR message raises ValueError;
    * an NLMSG_DONE message with a zero error code ends the iteration, and a
      non-zero code raises ValueError;
    * any other message must satisfy ``msg.is_type(msg_type)``, otherwise
      ValueError is raised.

    :param obj: object providing ``read_message()``, e.g. an ``Nlsock`` or a
        ``NetlinkTestTemplate`` instance.
    :param seq_number: sequence number of the request being answered.
    :param msg_type: message type every data part must match.
    """

    def __init__(self, obj, seq_number: int, msg_type):
        self._obj = obj
        self._seq = seq_number
        self._msg_type = msg_type
        # Latched once NLMSG_DONE has been seen: later next() calls must keep
        # raising StopIteration instead of blocking on (or consuming) another
        # socket read that belongs to a different request.
        self._done = False

    def __iter__(self):
        return self

    def __next__(self):
        if self._done:
            raise StopIteration
        msg = self._obj.read_message()
        if self._seq != msg.nl_hdr.nlmsg_seq:
            raise ValueError("bad sequence number")
        if msg.is_type(NlMsgType.NLMSG_ERROR):
            raise ValueError(
                "error while handling multipart msg: {}".format(msg.error_code)
            )
        if msg.is_type(NlMsgType.NLMSG_DONE):
            # NLMSG_DONE terminates the multipart sequence regardless of
            # whether it reports an error.
            self._done = True
            if msg.error_code == 0:
                raise StopIteration
            raise ValueError(
                "error listing some parts of the multipart msg: {}".format(
                    msg.error_code
                )
            )
        if not msg.is_type(self._msg_type):
            raise ValueError("bad message type: {}".format(msg))
        return msg
+
+
class NetlinkTestTemplate(object):
    """Helper base class for netlink tests.

    ``setup_netlink()`` must be called first; the other methods use the
    socket it creates.  Every message sent or received is printed so that a
    failing test leaves a readable trace.
    """

    # NOTE(review): presumably consumed by the test framework to ensure the
    # listed kernel modules are available - TODO confirm.
    REQUIRED_MODULES = ["netlink"]

    def setup_netlink(self, netlink_family: NlConst):
        """Create the ``NlHelper`` and an ``Nlsock`` for *netlink_family*."""
        self.helper = NlHelper()
        self.nlsock = Nlsock(netlink_family, self.helper)

    def write_message(self, msg):
        """Print *msg*, send it on the socket and dump the bytes sent."""
        print("")
        print("============= >> TX MESSAGE =============")
        msg.print_message()
        # Serialize once so the hex dump shows exactly the bytes that were
        # written, and the message is not encoded twice.
        data = bytes(msg)
        self.nlsock.write_data(data)
        msg.print_as_bytes(data, "-- DATA --")

    def read_message(self):
        """Read one message from the socket, print it and return it."""
        msg = self.nlsock.read_message()
        print("")
        print("============= << RX MESSAGE =============")
        msg.print_message()
        return msg

    def get_reply(self, tx_msg):
        """Send *tx_msg* and return the first reply with the same sequence
        number.

        Messages with other sequence numbers are read (and printed) and then
        skipped.  No timeout is applied here; this keeps reading until a
        matching message arrives.
        """
        self.write_message(tx_msg)
        while True:
            rx_msg = self.read_message()
            if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq:
                return rx_msg

    def read_msg_list(self, seq, msg_type):
        """Return all parts of the multipart reply to request *seq* as a list.

        Termination and error rules are those of NetlinkMultipartIterator.
        """
        return list(NetlinkMultipartIterator(self, seq, msg_type))