aboutsummaryrefslogtreecommitdiff
path: root/tests/atf_python/sys/net
diff options
context:
space:
mode:
Diffstat (limited to 'tests/atf_python/sys/net')
-rw-r--r--tests/atf_python/sys/net/Makefile11
-rw-r--r--tests/atf_python/sys/net/__init__.py0
-rwxr-xr-xtests/atf_python/sys/net/rtsock.py604
-rw-r--r--tests/atf_python/sys/net/tools.py100
-rw-r--r--tests/atf_python/sys/net/vnet.py559
5 files changed, 1274 insertions, 0 deletions
diff --git a/tests/atf_python/sys/net/Makefile b/tests/atf_python/sys/net/Makefile
new file mode 100644
index 000000000000..70d5b1a3284b
--- /dev/null
+++ b/tests/atf_python/sys/net/Makefile
@@ -0,0 +1,11 @@
+.include <src.opts.mk>
+
+.PATH: ${.CURDIR}
+
+PACKAGE=tests
+FILES= __init__.py rtsock.py tools.py vnet.py
+
+.include <bsd.own.mk>
+FILESDIR= ${TESTSBASE}/atf_python/sys/net
+
+.include <bsd.prog.mk>
diff --git a/tests/atf_python/sys/net/__init__.py b/tests/atf_python/sys/net/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
--- /dev/null
+++ b/tests/atf_python/sys/net/__init__.py
diff --git a/tests/atf_python/sys/net/rtsock.py b/tests/atf_python/sys/net/rtsock.py
new file mode 100755
index 000000000000..788e863f8b28
--- /dev/null
+++ b/tests/atf_python/sys/net/rtsock.py
@@ -0,0 +1,604 @@
+#!/usr/local/bin/python3
+import os
+import socket
+import struct
+import sys
+from ctypes import c_byte
+from ctypes import c_char
+from ctypes import c_int
+from ctypes import c_long
+from ctypes import c_uint32
+from ctypes import c_ulong
+from ctypes import c_ushort
+from ctypes import sizeof
+from ctypes import Structure
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+
+def roundup2(val: int, num: int) -> int:
+ if val % num:
+ return (val | (num - 1)) + 1
+ else:
+ return val
+
+
+class RtSockException(OSError):
+ pass
+
+
+class RtConst:
+ RTM_VERSION = 5
+ ALIGN = sizeof(c_long)
+
+ AF_INET = socket.AF_INET
+ AF_INET6 = socket.AF_INET6
+ AF_LINK = socket.AF_LINK
+
+ RTA_DST = 0x1
+ RTA_GATEWAY = 0x2
+ RTA_NETMASK = 0x4
+ RTA_GENMASK = 0x8
+ RTA_IFP = 0x10
+ RTA_IFA = 0x20
+ RTA_AUTHOR = 0x40
+ RTA_BRD = 0x80
+
+ RTM_ADD = 1
+ RTM_DELETE = 2
+ RTM_CHANGE = 3
+ RTM_GET = 4
+
+ RTF_UP = 0x1
+ RTF_GATEWAY = 0x2
+ RTF_HOST = 0x4
+ RTF_REJECT = 0x8
+ RTF_DYNAMIC = 0x10
+ RTF_MODIFIED = 0x20
+ RTF_DONE = 0x40
+ RTF_XRESOLVE = 0x200
+ RTF_LLINFO = 0x400
+ RTF_LLDATA = 0x400
+ RTF_STATIC = 0x800
+ RTF_BLACKHOLE = 0x1000
+ RTF_PROTO2 = 0x4000
+ RTF_PROTO1 = 0x8000
+ RTF_PROTO3 = 0x40000
+ RTF_FIXEDMTU = 0x80000
+ RTF_PINNED = 0x100000
+ RTF_LOCAL = 0x200000
+ RTF_BROADCAST = 0x400000
+ RTF_MULTICAST = 0x800000
+ RTF_STICKY = 0x10000000
+ RTF_RNH_LOCKED = 0x40000000
+ RTF_GWFLAG_COMPAT = 0x80000000
+
+ RTV_MTU = 0x1
+ RTV_HOPCOUNT = 0x2
+ RTV_EXPIRE = 0x4
+ RTV_RPIPE = 0x8
+ RTV_SPIPE = 0x10
+ RTV_SSTHRESH = 0x20
+ RTV_RTT = 0x40
+ RTV_RTTVAR = 0x80
+ RTV_WEIGHT = 0x100
+
+ @staticmethod
+ def get_props(prefix: str) -> List[str]:
+ return [n for n in dir(RtConst) if n.startswith(prefix)]
+
+ @staticmethod
+ def get_name(prefix: str, value: int) -> str:
+ props = RtConst.get_props(prefix)
+ for prop in props:
+ if getattr(RtConst, prop) == value:
+ return prop
+ return "U:{}:{}".format(prefix, value)
+
+ @staticmethod
+ def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]:
+ props = RtConst.get_props(prefix)
+ propmap = {getattr(RtConst, prop): prop for prop in props}
+ v = 1
+ ret = {}
+ while value:
+ if v & value:
+ if v in propmap:
+ ret[v] = propmap[v]
+ else:
+ ret[v] = hex(v)
+ value -= v
+ v *= 2
+ return ret
+
+ @staticmethod
+ def get_bitmask_str(prefix: str, value: int) -> str:
+ bmap = RtConst.get_bitmask_map(prefix, value)
+ return ",".join([v for k, v in bmap.items()])
+
+
+class RtMetrics(Structure):
+ _fields_ = [
+ ("rmx_locks", c_ulong),
+ ("rmx_mtu", c_ulong),
+ ("rmx_hopcount", c_ulong),
+ ("rmx_expire", c_ulong),
+ ("rmx_recvpipe", c_ulong),
+ ("rmx_sendpipe", c_ulong),
+ ("rmx_ssthresh", c_ulong),
+ ("rmx_rtt", c_ulong),
+ ("rmx_rttvar", c_ulong),
+ ("rmx_pksent", c_ulong),
+ ("rmx_weight", c_ulong),
+ ("rmx_nhidx", c_ulong),
+ ("rmx_filler", c_ulong * 2),
+ ]
+
+
+class RtMsgHdr(Structure):
+ _fields_ = [
+ ("rtm_msglen", c_ushort),
+ ("rtm_version", c_byte),
+ ("rtm_type", c_byte),
+ ("rtm_index", c_ushort),
+ ("_rtm_spare1", c_ushort),
+ ("rtm_flags", c_int),
+ ("rtm_addrs", c_int),
+ ("rtm_pid", c_int),
+ ("rtm_seq", c_int),
+ ("rtm_errno", c_int),
+ ("rtm_fmask", c_int),
+ ("rtm_inits", c_ulong),
+ ("rtm_rmx", RtMetrics),
+ ]
+
+
+class SockaddrIn(Structure):
+ _fields_ = [
+ ("sin_len", c_byte),
+ ("sin_family", c_byte),
+ ("sin_port", c_ushort),
+ ("sin_addr", c_uint32),
+ ("sin_zero", c_char * 8),
+ ]
+
+
+class SockaddrIn6(Structure):
+ _fields_ = [
+ ("sin6_len", c_byte),
+ ("sin6_family", c_byte),
+ ("sin6_port", c_ushort),
+ ("sin6_flowinfo", c_uint32),
+ ("sin6_addr", c_byte * 16),
+ ("sin6_scope_id", c_uint32),
+ ]
+
+
+class SockaddrDl(Structure):
+ _fields_ = [
+ ("sdl_len", c_byte),
+ ("sdl_family", c_byte),
+ ("sdl_index", c_ushort),
+ ("sdl_type", c_byte),
+ ("sdl_nlen", c_byte),
+ ("sdl_alen", c_byte),
+ ("sdl_slen", c_byte),
+ ("sdl_data", c_byte * 8),
+ ]
+
+
+class SaHelper(object):
+ @staticmethod
+ def is_ipv6(ip: str) -> bool:
+ return ":" in ip
+
+ @staticmethod
+ def ip_sa(ip: str, scopeid: int = 0) -> bytes:
+ if SaHelper.is_ipv6(ip):
+ return SaHelper.ip6_sa(ip, scopeid)
+ else:
+ return SaHelper.ip4_sa(ip)
+
+ @staticmethod
+ def ip4_sa(ip: str) -> bytes:
+ addr_int = int.from_bytes(socket.inet_pton(2, ip), sys.byteorder)
+ sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int)
+ return bytes(sin)
+
+ @staticmethod
+ def ip6_sa(ip6: str, scopeid: int) -> bytes:
+ addr_bytes = (c_byte * 16)()
+ for i, b in enumerate(socket.inet_pton(socket.AF_INET6, ip6)):
+ addr_bytes[i] = b
+ sin6 = SockaddrIn6(
+ sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid
+ )
+ return bytes(sin6)
+
+ @staticmethod
+ def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes:
+ sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype)
+ return bytes(sa)
+
+ @staticmethod
+ def pxlen4_sa(pxlen: int) -> bytes:
+ return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen))
+
+ @staticmethod
+ def pxlen_to_ip4(pxlen: int) -> str:
+ if pxlen == 32:
+ return "255.255.255.255"
+ else:
+ addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1)
+ addr_bytes = struct.pack("!I", addr)
+ return socket.inet_ntop(socket.AF_INET, addr_bytes)
+
+ @staticmethod
+ def pxlen6_sa(pxlen: int) -> bytes:
+ return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen))
+
+ @staticmethod
+ def pxlen_to_ip6(pxlen: int) -> str:
+ ip6_b = [0] * 16
+ start = 0
+ while pxlen > 8:
+ ip6_b[start] = 0xFF
+ pxlen -= 8
+ start += 1
+ ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1)
+ return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b))
+
+ @staticmethod
+ def print_sa_inet(sa: bytes):
+ if len(sa) < 8:
+ raise RtSockException("IPv4 sa size too small: {}".format(len(sa)))
+ addr = socket.inet_ntop(socket.AF_INET, sa[4:8])
+ return "{}".format(addr)
+
+ @staticmethod
+ def print_sa_inet6(sa: bytes):
+ if len(sa) < sizeof(SockaddrIn6):
+ raise RtSockException("IPv6 sa size too small: {}".format(len(sa)))
+ addr = socket.inet_ntop(socket.AF_INET6, sa[8:24])
+ scopeid = struct.unpack(">I", sa[24:28])[0]
+ return "{} scopeid {}".format(addr, scopeid)
+
+ @staticmethod
+ def print_sa_link(sa: bytes, hd: Optional[bool] = True):
+ if len(sa) < sizeof(SockaddrDl):
+ raise RtSockException("LINK sa size too small: {}".format(len(sa)))
+ sdl = SockaddrDl.from_buffer_copy(sa)
+ if sdl.sdl_index:
+ ifindex = "link#{} ".format(sdl.sdl_index)
+ else:
+ ifindex = ""
+ if sdl.sdl_nlen:
+ iface_offset = 8
+ if sdl.sdl_nlen + iface_offset > len(sa):
+ raise RtSockException(
+ "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa))
+ )
+ ifname = "ifname:{} ".format(
+ bytes.decode(sa[iface_offset : iface_offset + sdl.sdl_nlen])
+ )
+ else:
+ ifname = ""
+ return "{}{}".format(ifindex, ifname)
+
+ @staticmethod
+ def print_sa_unknown(sa: bytes):
+ return "unknown_type:{}".format(sa[1])
+
+ @classmethod
+ def print_sa(cls, sa: bytes, hd: Optional[bool] = False):
+ if sa[0] != len(sa):
+ raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa)))
+
+ if len(sa) < 2:
+ raise Exception(
+ "sa type {} too short: {}".format(
+ RtConst.get_name("AF_", sa[1]), len(sa)
+ )
+ )
+
+ if sa[1] == socket.AF_INET:
+ text = cls.print_sa_inet(sa)
+ elif sa[1] == socket.AF_INET6:
+ text = cls.print_sa_inet6(sa)
+ elif sa[1] == socket.AF_LINK:
+ text = cls.print_sa_link(sa)
+ else:
+ text = cls.print_sa_unknown(sa)
+ if hd:
+ dump = " [{!r}]".format(sa)
+ else:
+ dump = ""
+ return "{}{}".format(text, dump)
+
+
+class BaseRtsockMessage(object):
+ def __init__(self, rtm_type):
+ self.rtm_type = rtm_type
+ self.sa = SaHelper()
+
+ @staticmethod
+ def print_rtm_type(rtm_type):
+ return RtConst.get_name("RTM_", rtm_type)
+
+ @property
+ def rtm_type_str(self):
+ return self.print_rtm_type(self.rtm_type)
+
+
+class RtsockRtMessage(BaseRtsockMessage):
+ messages = [
+ RtConst.RTM_ADD,
+ RtConst.RTM_DELETE,
+ RtConst.RTM_CHANGE,
+ RtConst.RTM_GET,
+ ]
+
+ def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None):
+ super().__init__(rtm_type)
+ self.rtm_flags = 0
+ self.rtm_seq = rtm_seq
+ self._attrs = {}
+ self.rtm_errno = 0
+ self.rtm_pid = 0
+ self.rtm_inits = 0
+ self.rtm_rmx = RtMetrics()
+ self._orig_data = None
+ if dst_sa:
+ self.add_sa_attr(RtConst.RTA_DST, dst_sa)
+ if mask_sa:
+ self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa)
+
+ def add_sa_attr(self, attr_type, attr_bytes: bytes):
+ self._attrs[attr_type] = attr_bytes
+
+ def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0):
+ if ":" in ip_addr:
+ self.add_ip6_attr(attr_type, ip_addr, scopeid)
+ else:
+ self.add_ip4_attr(attr_type, ip_addr)
+
+ def add_ip4_attr(self, attr_type, ip: str):
+ self.add_sa_attr(attr_type, self.sa.ip_sa(ip))
+
+ def add_ip6_attr(self, attr_type, ip6: str, scopeid: int):
+ self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid))
+
+ def add_link_attr(self, attr_type, ifindex: Optional[int] = 0):
+ self.add_sa_attr(attr_type, self.sa.link_sa(ifindex))
+
+ def get_sa(self, attr_type) -> bytes:
+ return self._attrs.get(attr_type)
+
+ def print_message(self):
+ # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC>
+ if self._orig_data:
+ rtm_len = len(self._orig_data)
+ else:
+ rtm_len = len(bytes(self))
+ print(
+ "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format(
+ self.rtm_type_str,
+ rtm_len,
+ self.rtm_pid,
+ self.rtm_seq,
+ self.rtm_errno,
+ RtConst.get_bitmask_str("RTF_", self.rtm_flags),
+ )
+ )
+ rtm_addrs = sum(list(self._attrs.keys()))
+ print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs)))
+ for attr in sorted(self._attrs.keys()):
+ sa_data = SaHelper.print_sa(self._attrs[attr])
+ print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data))
+
+ def print_in_message(self):
+ print("vvvvvvvv IN vvvvvvvv")
+ self.print_message()
+ print()
+
+ def verify_sa_inet(self, sa_data):
+ if len(sa_data) < 8:
+ raise Exception("IPv4 sa size too small: {}".format(sa_data))
+ if sa_data[0] > len(sa_data):
+ raise Exception(
+ "IPv4 sin_len too big: {} vs sa size {}: {}".format(
+ sa_data[0], len(sa_data), sa_data
+ )
+ )
+ sin = SockaddrIn.from_buffer_copy(sa_data)
+ assert sin.sin_port == 0
+ assert sin.sin_zero == [0] * 8
+
+ def compare_sa(self, sa_type, sa_data):
+ if len(sa_data) < 4:
+ sa_type_name = RtConst.get_name("RTA_", sa_type)
+ raise Exception(
+ "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data))
+ )
+ our_sa = self._attrs[sa_type]
+ assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa)
+ assert len(sa_data) == len(our_sa)
+ assert sa_data == our_sa
+
+ def verify(self, rtm_type: int, rtm_sa):
+ assert self.rtm_type_str == self.print_rtm_type(rtm_type)
+ assert self.rtm_errno == 0
+ hdr = RtMsgHdr.from_buffer_copy(self._orig_data)
+ assert hdr._rtm_spare1 == 0
+ for sa_type, sa_data in rtm_sa.items():
+ if sa_type not in self._attrs:
+ sa_type_name = RtConst.get_name("RTA_", sa_type)
+ raise Exception("SA type {} not present".format(sa_type_name))
+ self.compare_sa(sa_type, sa_data)
+
+ @classmethod
+ def from_bytes(cls, data: bytes):
+ if len(data) < sizeof(RtMsgHdr):
+ raise Exception(
+ "messages size {} is less than expected {}".format(
+ len(data), sizeof(RtMsgHdr)
+ )
+ )
+ hdr = RtMsgHdr.from_buffer_copy(data)
+
+ self = cls(hdr.rtm_type)
+ self.rtm_flags = hdr.rtm_flags
+ self.rtm_seq = hdr.rtm_seq
+ self.rtm_errno = hdr.rtm_errno
+ self.rtm_pid = hdr.rtm_pid
+ self.rtm_inits = hdr.rtm_inits
+ self.rtm_rmx = hdr.rtm_rmx
+ self._orig_data = data
+
+ off = sizeof(RtMsgHdr)
+ v = 1
+ addrs_mask = hdr.rtm_addrs
+ while addrs_mask:
+ if addrs_mask & v:
+ addrs_mask -= v
+
+ if off + data[off] > len(data):
+ raise Exception(
+ "SA sizeof for {} > total message length: {}+{} > {}".format(
+ RtConst.get_name("RTA_", v), off, data[off], len(data)
+ )
+ )
+ self._attrs[v] = data[off : off + data[off]]
+ off += roundup2(data[off], RtConst.ALIGN)
+ v *= 2
+ return self
+
+ def __bytes__(self):
+ sz = sizeof(RtMsgHdr)
+ addrs_mask = 0
+ for k, v in self._attrs.items():
+ sz += roundup2(len(v), RtConst.ALIGN)
+ addrs_mask += k
+ hdr = RtMsgHdr(
+ rtm_msglen=sz,
+ rtm_version=RtConst.RTM_VERSION,
+ rtm_type=self.rtm_type,
+ rtm_flags=self.rtm_flags,
+ rtm_seq=self.rtm_seq,
+ rtm_addrs=addrs_mask,
+ rtm_inits=self.rtm_inits,
+ rtm_rmx=self.rtm_rmx,
+ )
+ buf = bytearray(sz)
+ buf[0 : sizeof(RtMsgHdr)] = hdr
+ off = sizeof(RtMsgHdr)
+ for attr in sorted(self._attrs.keys()):
+ v = self._attrs[attr]
+ sa_len = len(v)
+ buf[off : off + sa_len] = v
+ off += roundup2(len(v), RtConst.ALIGN)
+ return bytes(buf)
+
+
+class Rtsock:
+ def __init__(self):
+ self.socket = self._setup_rtsock()
+ self.rtm_seq = 1
+ self.msgmap = self.build_msgmap()
+
+ def build_msgmap(self):
+ classes = [RtsockRtMessage]
+ xmap = {}
+ for cls in classes:
+ for message in cls.messages:
+ xmap[message] = cls
+ return xmap
+
+ def get_seq(self):
+ ret = self.rtm_seq
+ self.rtm_seq += 1
+ return ret
+
+ def get_weight(self, weight) -> int:
+ if weight:
+ return weight
+ else:
+ return 1 # RT_DEFAULT_WEIGHT
+
+ def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]):
+ px = prefix.split("/")
+ addr_sa = SaHelper.ip_sa(px[0])
+ if len(px) > 1:
+ pxlen = int(px[1])
+ if SaHelper.is_ipv6(px[0]):
+ mask_sa = SaHelper.pxlen6_sa(pxlen)
+ else:
+ mask_sa = SaHelper.pxlen4_sa(pxlen)
+ else:
+ mask_sa = None
+ msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa)
+ if isinstance(gw, bytes):
+ msg.add_sa_attr(RtConst.RTA_GATEWAY, gw)
+ else:
+ # String
+ msg.add_ip_attr(RtConst.RTA_GATEWAY, gw)
+ return msg
+
+ def new_rtm_add(self, prefix: str, gw: Union[str, bytes]):
+ return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw)
+
+ def new_rtm_del(self, prefix: str, gw: Union[str, bytes]):
+ return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw)
+
+ def new_rtm_change(self, prefix: str, gw: Union[str, bytes]):
+ return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw)
+
+ def _setup_rtsock(self) -> socket.socket:
+ s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC)
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1)
+ return s
+
+ def print_hd(self, data: bytes):
+ width = 16
+ print("==========================================")
+ for chunk in [data[i : i + width] for i in range(0, len(data), width)]:
+ for b in chunk:
+ print("0x{:02X} ".format(b), end="")
+ print()
+ print()
+
+ def write_message(self, msg):
+ print("vvvvvvvv OUT vvvvvvvv")
+ msg.print_message()
+ print()
+ msg_bytes = bytes(msg)
+ ret = os.write(self.socket.fileno(), msg_bytes)
+ if ret != -1:
+ assert ret == len(msg_bytes)
+
+ def parse_message(self, data: bytes):
+ if len(data) < 4:
+ raise OSError("Short read from rtsock: {} bytes".format(len(data)))
+ rtm_type = data[4]
+ if rtm_type not in self.msgmap:
+ return None
+
+ def write_data(self, data: bytes):
+ self.socket.send(data)
+
+ def read_data(self, seq: Optional[int] = None) -> bytes:
+ while True:
+ data = self.socket.recv(4096)
+ if seq is None:
+ break
+ if len(data) > sizeof(RtMsgHdr):
+ hdr = RtMsgHdr.from_buffer_copy(data)
+ if hdr.rtm_seq == seq:
+ break
+ return data
+
+ def read_message(self) -> bytes:
+ data = self.read_data()
+ return self.parse_message(data)
diff --git a/tests/atf_python/sys/net/tools.py b/tests/atf_python/sys/net/tools.py
new file mode 100644
index 000000000000..44bd74d8578f
--- /dev/null
+++ b/tests/atf_python/sys/net/tools.py
@@ -0,0 +1,100 @@
+#!/usr/local/bin/python3
+import json
+import os
+import subprocess
+
+
+class ToolsHelper(object):
+ NETSTAT_PATH = "/usr/bin/netstat"
+ IFCONFIG_PATH = "/sbin/ifconfig"
+
+ @classmethod
+ def get_output(cls, cmd: str, verbose=False) -> str:
+ if verbose:
+ print("run: '{}'".format(cmd))
+ return os.popen(cmd).read()
+
+ @classmethod
+ def pf_rules(cls, rules, verbose=True):
+ pf_conf = ""
+ for r in rules:
+ pf_conf = pf_conf + r + "\n"
+
+ if verbose:
+ print("Set rules:")
+ print(pf_conf)
+
+ ps = subprocess.Popen("/sbin/pfctl -g -f -", shell=True,
+ stdin=subprocess.PIPE)
+ ps.communicate(bytes(pf_conf, 'utf-8'))
+ ret = ps.wait()
+ if ret != 0:
+ raise Exception("Failed to set pf rules %d" % ret)
+
+ if verbose:
+ cls.print_output("/sbin/pfctl -sr")
+
+ @classmethod
+ def print_output(cls, cmd: str, verbose=True):
+ if verbose:
+ print("======= {} =====".format(cmd))
+ print(cls.get_output(cmd))
+ if verbose:
+ print()
+
+ @classmethod
+ def print_net_debug(cls):
+ cls.print_output("ifconfig")
+ cls.print_output("netstat -rnW")
+
+ @classmethod
+ def set_sysctl(cls, oid, val):
+ cls.get_output("sysctl {}={}".format(oid, val))
+
+ @classmethod
+ def get_routes(cls, family: str, fibnum: int = 0):
+ family_key = {"inet": "-4", "inet6": "-6"}.get(family)
+ out = cls.get_output(
+ "{} {} -rnW -F {} --libxo json".format(cls.NETSTAT_PATH, family_key, fibnum)
+ )
+ js = json.loads(out)
+ js = js["statistics"]["route-information"]["route-table"]["rt-family"]
+ if js:
+ return js[0]["rt-entry"]
+ else:
+ return []
+
+ @classmethod
+ def get_nhops(cls, family: str, fibnum: int = 0):
+ family_key = {"inet": "-4", "inet6": "-6"}.get(family)
+ out = cls.get_output(
+ "{} {} -onW -F {} --libxo json".format(cls.NETSTAT_PATH, family_key, fibnum)
+ )
+ js = json.loads(out)
+ js = js["statistics"]["route-nhop-information"]["nhop-table"]["rt-family"]
+ if js:
+ return js[0]["nh-entry"]
+ else:
+ return []
+
+ @classmethod
+ def get_linklocals(cls):
+ ret = {}
+ ifname = None
+ ips = []
+ for line in cls.get_output(cls.IFCONFIG_PATH).splitlines():
+ if line[0].isalnum():
+ if ifname:
+ ret[ifname] = ips
+ ips = []
+ ifname = line.split(":")[0]
+ else:
+ words = line.split()
+ if words[0] == "inet6" and words[1].startswith("fe80"):
+ # inet6 fe80::1%lo0 prefixlen 64 scopeid 0x2
+ ip = words[1].split("%")[0]
+ scopeid = int(words[words.index("scopeid") + 1], 16)
+ ips.append((ip, scopeid))
+ if ifname:
+ ret[ifname] = ips
+ return ret
diff --git a/tests/atf_python/sys/net/vnet.py b/tests/atf_python/sys/net/vnet.py
new file mode 100644
index 000000000000..f75a3eaa693e
--- /dev/null
+++ b/tests/atf_python/sys/net/vnet.py
@@ -0,0 +1,559 @@
+#!/usr/local/bin/python3
+import copy
+import ipaddress
+import os
+import re
+import socket
+import sys
+import time
+from multiprocessing import connection
+from multiprocessing import Pipe
+from multiprocessing import Process
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+
+from atf_python.sys.net.tools import ToolsHelper
+from atf_python.utils import BaseTest
+from atf_python.utils import libc
+
+
+def run_cmd(cmd: str, verbose=True) -> str:
+ if verbose:
+ print("run: '{}'".format(cmd))
+ return os.popen(cmd).read()
+
+
+def get_topology_id(test_id: str) -> str:
+ """
+ Gets a unique topology id based on the pytest test_id.
+ "test_ip6_output.py::TestIP6Output::test_output6_pktinfo[ipandif]" ->
+ "TestIP6Output:test_output6_pktinfo[ipandif]"
+ """
+ return ":".join(test_id.split("::")[-2:])
+
+
+def convert_test_name(test_name: str) -> str:
+ """Convert test name to a string that can be used in the file/jail names"""
+ ret = ""
+ for char in test_name:
+ if char.isalnum() or char in ("_", "-", ":"):
+ ret += char
+ elif char in ("["):
+ ret += "_"
+ return ret
+
+
+class VnetInterface(object):
+ # defines from net/if_types.h
+ IFT_LOOP = 0x18
+ IFT_ETHER = 0x06
+
+ def __init__(self, iface_alias: str, iface_name: str):
+ self.name = iface_name
+ self.alias = iface_alias
+ self.vnet_name = ""
+ self.jailed = False
+ self.addr_map: Dict[str, Dict] = {"inet6": {}, "inet": {}}
+ self.prefixes4: List[List[str]] = []
+ self.prefixes6: List[List[str]] = []
+ if iface_name.startswith("lo"):
+ self.iftype = self.IFT_LOOP
+ else:
+ self.iftype = self.IFT_ETHER
+ self.ether = ToolsHelper.get_output("/sbin/ifconfig %s ether | awk '/ether/ { print $2; }'" % iface_name).rstrip()
+
+ @property
+ def ifindex(self):
+ return socket.if_nametoindex(self.name)
+
+ @property
+ def first_ipv6(self):
+ d = self.addr_map["inet6"]
+ return d[next(iter(d))]
+
+ @property
+ def first_ipv4(self):
+ d = self.addr_map["inet"]
+ return d[next(iter(d))]
+
+ def set_vnet(self, vnet_name: str):
+ self.vnet_name = vnet_name
+
+ def set_jailed(self, jailed: bool):
+ self.jailed = jailed
+
+ def run_cmd(self, cmd, verbose=False):
+ if self.vnet_name and not self.jailed:
+ cmd = "/usr/sbin/jexec {} {}".format(self.vnet_name, cmd)
+ return run_cmd(cmd, verbose)
+
+ @classmethod
+ def setup_loopback(cls, vnet_name: str):
+ lo = VnetInterface("", "lo0")
+ lo.set_vnet(vnet_name)
+ lo.setup_addr("127.0.0.1/8")
+ lo.turn_up()
+
+ @classmethod
+ def create_iface(cls, alias_name: str, iface_name: str) -> List["VnetInterface"]:
+ name = run_cmd("/sbin/ifconfig {} create".format(iface_name)).rstrip()
+ if not name:
+ raise Exception("Unable to create iface {}".format(iface_name))
+ if1 = cls(alias_name, name)
+ ret = [if1]
+ if name.startswith("epair"):
+ run_cmd("/sbin/ifconfig {} -txcsum -txcsum6".format(name))
+ if2 = cls(alias_name, name[:-1] + "b")
+ if1.epairb = if2
+ ret.append(if2);
+ return ret
+
+ def setup_addr(self, _addr: str):
+ addr = ipaddress.ip_interface(_addr)
+ if addr.version == 6:
+ family = "inet6"
+ cmd = "/sbin/ifconfig {} {} {}".format(self.name, family, addr)
+ else:
+ family = "inet"
+ if self.addr_map[family]:
+ cmd = "/sbin/ifconfig {} alias {}".format(self.name, addr)
+ else:
+ cmd = "/sbin/ifconfig {} {} {}".format(self.name, family, addr)
+ self.run_cmd(cmd)
+ self.addr_map[family][str(addr.ip)] = addr
+
+ def delete_addr(self, _addr: str):
+ addr = ipaddress.ip_address(_addr)
+ if addr.version == 6:
+ family = "inet6"
+ cmd = "/sbin/ifconfig {} inet6 {} delete".format(self.name, addr)
+ else:
+ family = "inet"
+ cmd = "/sbin/ifconfig {} -alias {}".format(self.name, addr)
+ self.run_cmd(cmd)
+ del self.addr_map[family][str(addr)]
+
+ def turn_up(self):
+ cmd = "/sbin/ifconfig {} up".format(self.name)
+ self.run_cmd(cmd)
+
+ def enable_ipv6(self):
+ cmd = "/usr/sbin/ndp -i {} -- -disabled".format(self.name)
+ self.run_cmd(cmd)
+
+ def has_tentative(self) -> bool:
+ """True if an interface has some addresses in tenative state"""
+ cmd = "/sbin/ifconfig {} inet6".format(self.name)
+ out = self.run_cmd(cmd, verbose=False)
+ for line in out.splitlines():
+ if "tentative" in line:
+ return True
+ return False
+
+
+class IfaceFactory(object):
+ INTERFACES_FNAME = "created_ifaces.lst"
+ AUTODELETE_TYPES = ("epair", "gif", "gre", "lo", "tap", "tun")
+
+ def __init__(self):
+ self.file_name = self.INTERFACES_FNAME
+
+ def _register_iface(self, iface_name: str):
+ with open(self.file_name, "a") as f:
+ f.write(iface_name + "\n")
+
+ def _list_ifaces(self) -> List[str]:
+ ret: List[str] = []
+ try:
+ with open(self.file_name, "r") as f:
+ for line in f:
+ ret.append(line.strip())
+ except OSError:
+ pass
+ return ret
+
+ def create_iface(self, alias_name: str, iface_name: str) -> List[VnetInterface]:
+ ifaces = VnetInterface.create_iface(alias_name, iface_name)
+ for iface in ifaces:
+ if not self.is_autodeleted(iface.name):
+ self._register_iface(iface.name)
+ return ifaces
+
+ @staticmethod
+ def is_autodeleted(iface_name: str) -> bool:
+ if iface_name == "lo0":
+ return False
+ iface_type = re.split(r"\d+", iface_name)[0]
+ return iface_type in IfaceFactory.AUTODELETE_TYPES
+
+ def cleanup_vnet_interfaces(self, vnet_name: str) -> List[str]:
+ """Destroys"""
+ ifaces_lst = ToolsHelper.get_output(
+ "/usr/sbin/jexec {} /sbin/ifconfig -l".format(vnet_name)
+ )
+ for iface_name in ifaces_lst.split():
+ if not self.is_autodeleted(iface_name):
+ if iface_name not in self._list_ifaces():
+ print("Skipping interface {}:{}".format(vnet_name, iface_name))
+ continue
+ run_cmd(
+ "/usr/sbin/jexec {} /sbin/ifconfig {} destroy".format(vnet_name, iface_name)
+ )
+
+ def cleanup(self):
+ try:
+ os.unlink(self.INTERFACES_FNAME)
+ except OSError:
+ pass
+
+
+class VnetInstance(object):
+ def __init__(
+ self, vnet_alias: str, vnet_name: str, jid: int, ifaces: List[VnetInterface]
+ ):
+ self.name = vnet_name
+ self.alias = vnet_alias # reference in the test topology
+ self.jid = jid
+ self.ifaces = ifaces
+ self.iface_alias_map = {} # iface.alias: iface
+ self.iface_map = {} # iface.name: iface
+ for iface in ifaces:
+ iface.set_vnet(vnet_name)
+ iface.set_jailed(True)
+ self.iface_alias_map[iface.alias] = iface
+ self.iface_map[iface.name] = iface
+ # Allow reference to interfce aliases as attributes
+ setattr(self, iface.alias, iface)
+ self.need_dad = False # Disable duplicate address detection by default
+ self.attached = False
+ self.pipe = None
+ self.subprocess = None
+
+ def run_vnet_cmd(self, cmd, verbose=True):
+ if not self.attached:
+ cmd = "/usr/sbin/jexec {} {}".format(self.name, cmd)
+ return run_cmd(cmd, verbose)
+
+ def disable_dad(self):
+ self.run_vnet_cmd("/sbin/sysctl net.inet6.ip6.dad_count=0")
+
+ def set_pipe(self, pipe):
+ self.pipe = pipe
+
+ def set_subprocess(self, p):
+ self.subprocess = p
+
+ @staticmethod
+ def attach_jid(jid: int):
+ error_code = libc.jail_attach(jid)
+ if error_code != 0:
+ raise Exception("jail_attach() failed: errno {}".format(error_code))
+
+ def attach(self):
+ self.attach_jid(self.jid)
+ self.attached = True
+
+
+class VnetFactory(object):
+ JAILS_FNAME = "created_jails.lst"
+
+ def __init__(self, topology_id: str):
+ self.topology_id = topology_id
+ self.file_name = self.JAILS_FNAME
+ self._vnets: List[str] = []
+
+ def _register_vnet(self, vnet_name: str):
+ self._vnets.append(vnet_name)
+ with open(self.file_name, "a") as f:
+ f.write(vnet_name + "\n")
+
+ @staticmethod
+ def _wait_interfaces(vnet_name: str, ifaces: List[str]) -> List[str]:
+ cmd = "/usr/sbin/jexec {} /sbin/ifconfig -l".format(vnet_name)
+ not_matched: List[str] = []
+ for i in range(50):
+ vnet_ifaces = run_cmd(cmd).strip().split(" ")
+ not_matched = []
+ for iface_name in ifaces:
+ if iface_name not in vnet_ifaces:
+ not_matched.append(iface_name)
+ if len(not_matched) == 0:
+ return []
+ time.sleep(0.1)
+ return not_matched
+
+ def create_vnet(self, vnet_alias: str, ifaces: List[VnetInterface], opts: List[str]):
+ vnet_name = "pytest:{}".format(convert_test_name(self.topology_id))
+ if self._vnets:
+ # add number to distinguish jails
+ vnet_name = "{}_{}".format(vnet_name, len(self._vnets) + 1)
+ iface_cmds = " ".join(["vnet.interface={}".format(i.name) for i in ifaces])
+ opt_cmds = " ".join(["{}".format(i) for i in opts])
+ cmd = "/usr/sbin/jail -i -c name={} persist vnet {} {}".format(
+ vnet_name, iface_cmds, opt_cmds
+ )
+ jid = 0
+ try:
+ jid_str = run_cmd(cmd)
+ jid = int(jid_str)
+ except ValueError:
+ print("Jail creation failed, output: {}".format(jid_str))
+ raise
+ self._register_vnet(vnet_name)
+
+ # Run expedited version of routing
+ VnetInterface.setup_loopback(vnet_name)
+
+ not_found = self._wait_interfaces(vnet_name, [i.name for i in ifaces])
+ if not_found:
+ raise Exception(
+ "Interfaces {} has not appeared in vnet {}".format(not_found, vnet_name)
+ )
+ return VnetInstance(vnet_alias, vnet_name, jid, ifaces)
+
+ def cleanup(self):
+ iface_factory = IfaceFactory()
+ try:
+ with open(self.file_name) as f:
+ for line in f:
+ vnet_name = line.strip()
+ iface_factory.cleanup_vnet_interfaces(vnet_name)
+ run_cmd("/usr/sbin/jail -r {}".format(vnet_name))
+ os.unlink(self.JAILS_FNAME)
+ except OSError:
+ pass
+
+
+class SingleInterfaceMap(NamedTuple):
+ ifaces: List[VnetInterface]
+ vnet_aliases: List[str]
+
+
+class ObjectsMap(NamedTuple):
+ iface_map: Dict[str, SingleInterfaceMap] # keyed by ifX
+ vnet_map: Dict[str, VnetInstance] # keyed by vnetX
+ topo_map: Dict # self.TOPOLOGY
+
+
+class VnetTestTemplate(BaseTest):
+ NEED_ROOT: bool = True
+ TOPOLOGY = {}
+
+ def _require_default_modules(self):
+ libc.kldload("if_epair.ko")
+ self.require_module("if_epair")
+
+ def _get_vnet_handler(self, vnet_alias: str):
+ handler_name = "{}_handler".format(vnet_alias)
+ return getattr(self, handler_name, None)
+
+ def _setup_vnet(self, vnet: VnetInstance, obj_map: Dict, pipe):
+ """Base Handler to setup given VNET.
+ Can be run in a subprocess. If so, passes control to the special
+ vnetX_handler() after setting up interface addresses
+ """
+ vnet.attach()
+ print("# setup_vnet({})".format(vnet.name))
+ if pipe is not None:
+ vnet.set_pipe(pipe)
+
+ topo = obj_map.topo_map
+ ipv6_ifaces = []
+ # Disable DAD
+ if not vnet.need_dad:
+ vnet.disable_dad()
+ for iface in vnet.ifaces:
+ # check index of vnet within an interface
+ # as we have prefixes for both ends of the interface
+ iface_map = obj_map.iface_map[iface.alias]
+ idx = iface_map.vnet_aliases.index(vnet.alias)
+ prefixes6 = topo[iface.alias].get("prefixes6", [])
+ prefixes4 = topo[iface.alias].get("prefixes4", [])
+ if prefixes6 or prefixes4:
+ ipv6_ifaces.append(iface)
+ iface.turn_up()
+ if prefixes6:
+ iface.enable_ipv6()
+ for prefix in prefixes6 + prefixes4:
+ if prefix[idx]:
+ iface.setup_addr(prefix[idx])
+ for iface in ipv6_ifaces:
+ while iface.has_tentative():
+ time.sleep(0.1)
+
+ # Run actual handler
+ handler = self._get_vnet_handler(vnet.alias)
+ if handler:
+ # Do unbuffered stdout for children
+ # so the logs are present if the child hangs
+ sys.stdout.reconfigure(line_buffering=True)
+ self.drop_privileges()
+ handler(vnet)
+
+ def _get_topo_ifmap(self, topo: Dict):
+ iface_factory = IfaceFactory()
+ iface_map: Dict[str, SingleInterfaceMap] = {}
+ iface_aliases = set()
+ for obj_name, obj_data in topo.items():
+ if obj_name.startswith("vnet"):
+ for iface_alias in obj_data["ifaces"]:
+ iface_aliases.add(iface_alias)
+ for iface_alias in iface_aliases:
+ print("Creating {}".format(iface_alias))
+ iface_data = topo[iface_alias]
+ iface_type = iface_data.get("type", "epair")
+ ifaces = iface_factory.create_iface(iface_alias, iface_type)
+ smap = SingleInterfaceMap(ifaces, [])
+ iface_map[iface_alias] = smap
+ return iface_map
+
+ def setup_topology(self, topo: Dict, topology_id: str):
+ """Creates jails & interfaces for the provided topology"""
+ vnet_map = {}
+ vnet_factory = VnetFactory(topology_id)
+ iface_map = self._get_topo_ifmap(topo)
+ for obj_name, obj_data in topo.items():
+ if obj_name.startswith("vnet"):
+ vnet_ifaces = []
+ for iface_alias in obj_data["ifaces"]:
+ # epair creates 2 interfaces, grab first _available_
+ # and map it to the VNET being created
+ idx = len(iface_map[iface_alias].vnet_aliases)
+ iface_map[iface_alias].vnet_aliases.append(obj_name)
+ vnet_ifaces.append(iface_map[iface_alias].ifaces[idx])
+ opts = []
+ if "opts" in obj_data:
+ opts = obj_data["opts"]
+ vnet = vnet_factory.create_vnet(obj_name, vnet_ifaces, opts)
+ vnet_map[obj_name] = vnet
+ # Allow reference to VNETs as attributes
+ setattr(self, obj_name, vnet)
+ # Debug output
+ print("============= TEST TOPOLOGY =============")
+ for vnet_alias, vnet in vnet_map.items():
+ print("# vnet {} -> {}".format(vnet.alias, vnet.name), end="")
+ handler = self._get_vnet_handler(vnet.alias)
+ if handler:
+ print(" handler: {}".format(handler.__name__), end="")
+ print()
+ for iface_alias, iface_data in iface_map.items():
+ vnets = iface_data.vnet_aliases
+ ifaces: List[VnetInterface] = iface_data.ifaces
+ if len(vnets) == 1 and len(ifaces) == 2:
+ print(
+ "# iface {}: {}::{} -> main::{}".format(
+ iface_alias, vnets[0], ifaces[0].name, ifaces[1].name
+ )
+ )
+ elif len(vnets) == 2 and len(ifaces) == 2:
+ print(
+ "# iface {}: {}::{} -> {}::{}".format(
+ iface_alias, vnets[0], ifaces[0].name, vnets[1], ifaces[1].name
+ )
+ )
+ else:
+ print(
+ "# iface {}: ifaces: {} vnets: {}".format(
+ iface_alias, vnets, [i.name for i in ifaces]
+ )
+ )
+ print()
+ return ObjectsMap(iface_map, vnet_map, topo)
+
+ def setup_method(self, _method):
+ """Sets up all the required topology and handlers for the given test"""
+ super().setup_method(_method)
+ self._require_default_modules()
+
+ # TestIP6Output.test_output6_pktinfo[ipandif]
+ topology_id = get_topology_id(self.test_id)
+ topology = self.TOPOLOGY
+ # First, setup kernel objects - interfaces & vnets
+ obj_map = self.setup_topology(topology, topology_id)
+ main_vnet = None # one without subprocess handler
+ for vnet_alias, vnet in obj_map.vnet_map.items():
+ if self._get_vnet_handler(vnet_alias):
+ # Need subprocess to run
+ parent_pipe, child_pipe = Pipe()
+ p = Process(
+ target=self._setup_vnet,
+ args=(
+ vnet,
+ obj_map,
+ child_pipe,
+ ),
+ )
+ vnet.set_pipe(parent_pipe)
+ vnet.set_subprocess(p)
+ p.start()
+ else:
+ if main_vnet is not None:
+ raise Exception("there can be only 1 VNET w/o handler")
+ main_vnet = vnet
+ # Main vnet needs to be the last, so all the other subprocesses
+ # are started & their pipe handles collected
+ self.vnet = main_vnet
+ self._setup_vnet(main_vnet, obj_map, None)
+ # Save state for the main handler
+ self.iface_map = obj_map.iface_map
+ self.vnet_map = obj_map.vnet_map
+ self.drop_privileges()
+
+ def cleanup(self, test_id: str):
+ # pytest test id: file::class::test_name
+ topology_id = get_topology_id(self.test_id)
+
+ print("============= vnet cleanup =============")
+ print("# topology_id: '{}'".format(topology_id))
+ VnetFactory(topology_id).cleanup()
+ IfaceFactory().cleanup()
+
+ def wait_object(self, pipe, timeout=5):
+ if pipe.poll(timeout):
+ return pipe.recv()
+ raise TimeoutError
+
+ def wait_objects_any(self, pipe_list, timeout=5):
+ objects = connection.wait(pipe_list, timeout)
+ if objects:
+ return objects[0].recv()
+ raise TimeoutError
+
+ def send_object(self, pipe, obj):
+ pipe.send(obj)
+
+ def wait(self):
+ while True:
+ time.sleep(1)
+
+ @property
+ def curvnet(self):
+ pass
+
+
+class SingleVnetTestTemplate(VnetTestTemplate):
+ IPV6_PREFIXES: List[str] = []
+ IPV4_PREFIXES: List[str] = []
+ IFTYPE = "epair"
+
+ def _setup_default_topology(self):
+ topology = copy.deepcopy(
+ {
+ "vnet1": {"ifaces": ["if1"]},
+ "if1": {"type": self.IFTYPE, "prefixes4": [], "prefixes6": []},
+ }
+ )
+ for prefix in self.IPV6_PREFIXES:
+ topology["if1"]["prefixes6"].append((prefix,))
+ for prefix in self.IPV4_PREFIXES:
+ topology["if1"]["prefixes4"].append((prefix,))
+ return topology
+
+ def setup_method(self, method):
+ if not getattr(self, "TOPOLOGY", None):
+ self.TOPOLOGY = self._setup_default_topology()
+ else:
+ names = self.TOPOLOGY.keys()
+ assert len([n for n in names if n.startswith("vnet")]) == 1
+ super().setup_method(method)