test-driver: add mypy support

It's a good idea to expand this in future to test code as well,
so we get type checking there as well.
This commit is contained in:
Jörg Thalheim 2019-11-08 10:01:29 +00:00
parent 556a169f14
commit 03e6ca15e2
No known key found for this signature in database
GPG Key ID: B3F5D81B0C6967C4
2 changed files with 103 additions and 91 deletions

View File

@ -1,6 +1,5 @@
#! /somewhere/python3
from contextlib import contextmanager
from contextlib import contextmanager, _GeneratorContextManager
from xml.sax.saxutils import XMLGenerator
import _thread
import atexit
@ -8,7 +7,7 @@ import json
import os
import ptpython.repl
import pty
import queue
from queue import Queue, Empty
import re
import shutil
import socket
@ -17,6 +16,7 @@ import sys
import tempfile
import time
import unicodedata
from typing import Tuple, TextIO, Any, Callable, Dict, Iterator, Optional, List
CHAR_TO_KEY = {
"A": "shift-a",
@ -81,12 +81,18 @@ CHAR_TO_KEY = {
")": "shift-0x0B",
}
# Forward references
nr_tests: int
nr_succeeded: int
log: "Logger"
machines: "List[Machine]"
def eprint(*args, **kwargs):
def eprint(*args: object, **kwargs: Any) -> None:
print(*args, file=sys.stderr, **kwargs)
def create_vlan(vlan_nr):
def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]:
global log
log.log("starting VDE switch for network {}".format(vlan_nr))
vde_socket = os.path.abspath("./vde{}.ctl".format(vlan_nr))
@ -110,7 +116,7 @@ def create_vlan(vlan_nr):
return (vlan_nr, vde_socket, vde_process, fd)
def retry(fn):
def retry(fn: Callable) -> None:
"""Call the given function repeatedly, with 1 second intervals,
until it returns True or a timeout is reached.
"""
@ -125,52 +131,52 @@ def retry(fn):
class Logger:
def __init__(self):
def __init__(self) -> None:
self.logfile = os.environ.get("LOGFILE", "/dev/null")
self.logfile_handle = open(self.logfile, "wb")
self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
self.queue = queue.Queue(1000)
self.queue: "Queue[Dict[str, str]]" = Queue(1000)
self.xml.startDocument()
self.xml.startElement("logfile", attrs={})
def close(self):
def close(self) -> None:
self.xml.endElement("logfile")
self.xml.endDocument()
self.logfile_handle.close()
def sanitise(self, message):
def sanitise(self, message: str) -> str:
return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")
def maybe_prefix(self, message, attributes):
def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
if "machine" in attributes:
return "{}: {}".format(attributes["machine"], message)
return message
def log_line(self, message, attributes):
def log_line(self, message: str, attributes: Dict[str, str]) -> None:
self.xml.startElement("line", attributes)
self.xml.characters(message)
self.xml.endElement("line")
def log(self, message, attributes={}):
def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
eprint(self.maybe_prefix(message, attributes))
self.drain_log_queue()
self.log_line(message, attributes)
def enqueue(self, message):
def enqueue(self, message: Dict[str, str]) -> None:
self.queue.put(message)
def drain_log_queue(self):
def drain_log_queue(self) -> None:
try:
while True:
item = self.queue.get_nowait()
attributes = {"machine": item["machine"], "type": "serial"}
self.log_line(self.sanitise(item["msg"]), attributes)
except queue.Empty:
except Empty:
pass
@contextmanager
def nested(self, message, attributes={}):
def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
eprint(self.maybe_prefix(message, attributes))
self.xml.startElement("nest", attrs={})
@ -189,24 +195,22 @@ class Logger:
class Machine:
def __init__(self, args):
def __init__(self, args: Dict[str, Any]) -> None:
if "name" in args:
self.name = args["name"]
else:
self.name = "machine"
try:
cmd = args["startCommand"]
self.name = re.search("run-(.+)-vm$", cmd).group(1)
except KeyError:
pass
except AttributeError:
pass
cmd = args.get("startCommand", None)
if cmd:
match = re.search("run-(.+)-vm$", cmd)
if match:
self.name = match.group(1)
self.script = args.get("startCommand", self.create_startcommand(args))
tmp_dir = os.environ.get("TMPDIR", tempfile.gettempdir())
def create_dir(name):
def create_dir(name: str) -> str:
path = os.path.join(tmp_dir, name)
os.makedirs(path, mode=0o700, exist_ok=True)
return path
@ -216,14 +220,14 @@ class Machine:
self.booted = False
self.connected = False
self.pid = None
self.pid: Optional[int] = None
self.socket = None
self.monitor = None
self.logger = args["log"]
self.monitor: Optional[socket.socket] = None
self.logger: Logger = args["log"]
self.allow_reboot = args.get("allowReboot", False)
@staticmethod
def create_startcommand(args):
def create_startcommand(args: Dict[str, str]) -> str:
net_backend = "-netdev user,id=net0"
net_frontend = "-device virtio-net-pci,netdev=net0"
@ -273,30 +277,32 @@ class Machine:
return start_command
def is_up(self):
def is_up(self) -> bool:
return self.booted and self.connected
def log(self, msg):
def log(self, msg: str) -> None:
self.logger.log(msg, {"machine": self.name})
def nested(self, msg, attrs={}):
def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
my_attrs = {"machine": self.name}
my_attrs.update(attrs)
return self.logger.nested(msg, my_attrs)
def wait_for_monitor_prompt(self):
def wait_for_monitor_prompt(self) -> str:
assert self.monitor is not None
while True:
answer = self.monitor.recv(1024).decode()
if answer.endswith("(qemu) "):
return answer
def send_monitor_command(self, command):
def send_monitor_command(self, command: str) -> str:
message = ("{}\n".format(command)).encode()
self.log("sending monitor command: {}".format(command))
assert self.monitor is not None
self.monitor.send(message)
return self.wait_for_monitor_prompt()
def wait_for_unit(self, unit, user=None):
def wait_for_unit(self, unit: str, user: Optional[str] = None) -> bool:
while True:
info = self.get_unit_info(unit, user)
state = info["ActiveState"]
@ -316,7 +322,7 @@ class Machine:
if state == "active":
return True
def get_unit_info(self, unit, user=None):
def get_unit_info(self, unit: str, user: Optional[str] = None) -> Dict[str, str]:
status, lines = self.systemctl('--no-pager show "{}"'.format(unit), user)
if status != 0:
raise Exception(
@ -327,8 +333,9 @@ class Machine:
line_pattern = re.compile(r"^([^=]+)=(.*)$")
def tuple_from_line(line):
def tuple_from_line(line: str) -> Tuple[str, str]:
match = line_pattern.match(line)
assert match is not None
return match[1], match[2]
return dict(
@ -337,7 +344,7 @@ class Machine:
if line_pattern.match(line)
)
def systemctl(self, q, user=None):
def systemctl(self, q: str, user: Optional[str] = None) -> Tuple[int, str]:
if user is not None:
q = q.replace("'", "\\'")
return self.execute(
@ -349,7 +356,7 @@ class Machine:
)
return self.execute("systemctl {}".format(q))
def require_unit_state(self, unit, require_state="active"):
def require_unit_state(self, unit: str, require_state: str = "active") -> None:
with self.nested(
"checking if unit {} has reached state '{}'".format(unit, require_state)
):
@ -361,7 +368,7 @@ class Machine:
+ "'active' but it is in state {}".format(state)
)
def execute(self, command):
def execute(self, command: str) -> Tuple[int, str]:
self.connect()
out_command = "( {} ); echo '|!EOF' $?\n".format(command)
@ -379,7 +386,7 @@ class Machine:
return (status_code, output)
output += chunk
def succeed(self, *commands):
def succeed(self, *commands: str) -> str:
"""Execute each command and check that it succeeds."""
output = ""
for command in commands:
@ -393,7 +400,7 @@ class Machine:
output += out
return output
def fail(self, *commands):
def fail(self, *commands: str) -> None:
"""Execute each command and check that it fails."""
for command in commands:
with self.nested("must fail: {}".format(command)):
@ -403,21 +410,21 @@ class Machine:
"command `{}` unexpectedly succeeded".format(command)
)
def wait_until_succeeds(self, command):
def wait_until_succeeds(self, command: str) -> str:
with self.nested("waiting for success: {}".format(command)):
while True:
status, output = self.execute(command)
if status == 0:
return output
def wait_until_fails(self, command):
def wait_until_fails(self, command: str) -> str:
with self.nested("waiting for failure: {}".format(command)):
while True:
status, output = self.execute(command)
if status != 0:
return output
def wait_for_shutdown(self):
def wait_for_shutdown(self) -> None:
if not self.booted:
return
@ -429,14 +436,14 @@ class Machine:
self.booted = False
self.connected = False
def get_tty_text(self, tty):
def get_tty_text(self, tty: str) -> str:
status, output = self.execute(
"fold -w$(stty -F /dev/tty{0} size | "
"awk '{{print $2}}') /dev/vcs{0}".format(tty)
)
return output
def wait_until_tty_matches(self, tty, regexp):
def wait_until_tty_matches(self, tty: str, regexp: str) -> bool:
matcher = re.compile(regexp)
with self.nested("waiting for {} to appear on tty {}".format(regexp, tty)):
while True:
@ -444,43 +451,43 @@ class Machine:
if len(matcher.findall(text)) > 0:
return True
def send_chars(self, chars):
def send_chars(self, chars: List[str]) -> None:
with self.nested("sending keys {}".format(chars)):
for char in chars:
self.send_key(char)
def wait_for_file(self, filename):
def wait_for_file(self, filename: str) -> bool:
with self.nested("waiting for file {}".format(filename)):
while True:
status, _ = self.execute("test -e {}".format(filename))
if status == 0:
return True
def wait_for_open_port(self, port):
def port_is_open(_):
def wait_for_open_port(self, port: int) -> None:
def port_is_open(_: Any) -> bool:
status, _ = self.execute("nc -z localhost {}".format(port))
return status == 0
with self.nested("waiting for TCP port {}".format(port)):
retry(port_is_open)
def wait_for_closed_port(self, port):
def port_is_closed(_):
def wait_for_closed_port(self, port: int) -> None:
def port_is_closed(_: Any) -> bool:
status, _ = self.execute("nc -z localhost {}".format(port))
return status != 0
retry(port_is_closed)
def start_job(self, jobname, user=None):
def start_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]:
return self.systemctl("start {}".format(jobname), user)
def stop_job(self, jobname, user=None):
def stop_job(self, jobname: str, user: Optional[str] = None) -> Tuple[int, str]:
return self.systemctl("stop {}".format(jobname), user)
def wait_for_job(self, jobname):
def wait_for_job(self, jobname: str) -> bool:
return self.wait_for_unit(jobname)
def connect(self):
def connect(self) -> None:
if self.connected:
return
@ -496,7 +503,7 @@ class Machine:
self.log("(connecting took {:.2f} seconds)".format(toc - tic))
self.connected = True
def screenshot(self, filename):
def screenshot(self, filename: str) -> None:
out_dir = os.environ.get("out", os.getcwd())
word_pattern = re.compile(r"^\w+$")
if word_pattern.match(filename):
@ -513,12 +520,12 @@ class Machine:
if ret.returncode != 0:
raise Exception("Cannot convert screenshot")
def dump_tty_contents(self, tty):
def dump_tty_contents(self, tty: str) -> None:
"""Debugging: Dump the contents of the TTY<n>
"""
self.execute("fold -w 80 /dev/vcs{} | systemd-cat".format(tty))
def get_screen_text(self):
def get_screen_text(self) -> str:
if shutil.which("tesseract") is None:
raise Exception("get_screen_text used but enableOCR is false")
@ -546,30 +553,30 @@ class Machine:
return ret.stdout.decode("utf-8")
def wait_for_text(self, regex):
def screen_matches(last):
def wait_for_text(self, regex: str) -> None:
def screen_matches(last: bool) -> bool:
text = self.get_screen_text()
m = re.search(regex, text)
matches = re.search(regex, text) is not None
if last and not m:
if last and not matches:
self.log("Last OCR attempt failed. Text was: {}".format(text))
return m
return matches
with self.nested("waiting for {} to appear on screen".format(regex)):
retry(screen_matches)
def send_key(self, key):
def send_key(self, key: str) -> None:
key = CHAR_TO_KEY.get(key, key)
self.send_monitor_command("sendkey {}".format(key))
def start(self):
def start(self) -> None:
if self.booted:
return
self.log("starting vm")
def create_socket(path):
def create_socket(path: str) -> socket.socket:
if os.path.exists(path):
os.unlink(path)
s = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
@ -619,9 +626,9 @@ class Machine:
self.monitor, _ = self.monitor_socket.accept()
self.shell, _ = self.shell_socket.accept()
def process_serial_output():
for line in self.process.stdout:
line = line.decode("unicode_escape").replace("\r", "").rstrip()
def process_serial_output() -> None:
for _line in self.process.stdout:
line = _line.decode("unicode_escape").replace("\r", "").rstrip()
eprint("{} # {}".format(self.name, line))
self.logger.enqueue({"msg": line, "machine": self.name})
@ -634,14 +641,14 @@ class Machine:
self.log("QEMU running (pid {})".format(self.pid))
def shutdown(self):
def shutdown(self) -> None:
if not self.booted:
return
self.shell.send("poweroff\n".encode())
self.wait_for_shutdown()
def crash(self):
def crash(self) -> None:
if not self.booted:
return
@ -649,7 +656,7 @@ class Machine:
self.send_monitor_command("quit")
self.wait_for_shutdown()
def wait_for_x(self):
def wait_for_x(self) -> None:
"""Wait until it is possible to connect to the X server. Note that
testing the existence of /tmp/.X11-unix/X0 is insufficient.
"""
@ -666,15 +673,15 @@ class Machine:
if status == 0:
return
def get_window_names(self):
def get_window_names(self) -> List[str]:
return self.succeed(
r"xwininfo -root -tree | sed 's/.*0x[0-9a-f]* \"\([^\"]*\)\".*/\1/; t; d'"
).splitlines()
def wait_for_window(self, regexp):
def wait_for_window(self, regexp: str) -> None:
pattern = re.compile(regexp)
def window_is_visible(last_try):
def window_is_visible(last_try: bool) -> bool:
names = self.get_window_names()
if last_try:
self.log(
@ -687,10 +694,10 @@ class Machine:
with self.nested("Waiting for a window to appear"):
retry(window_is_visible)
def sleep(self, secs):
def sleep(self, secs: int) -> None:
time.sleep(secs)
def forward_port(self, host_port=8080, guest_port=80):
def forward_port(self, host_port: int = 8080, guest_port: int = 80) -> None:
"""Forward a TCP port on the host to a TCP port on the guest.
Useful during interactive testing.
"""
@ -698,43 +705,46 @@ class Machine:
"hostfwd_add tcp::{}-:{}".format(host_port, guest_port)
)
def block(self):
def block(self) -> None:
"""Make the machine unreachable by shutting down eth1 (the multicast
interface used to talk to the other VMs). We keep eth0 up so that
the test driver can continue to talk to the machine.
"""
self.send_monitor_command("set_link virtio-net-pci.1 off")
def unblock(self):
def unblock(self) -> None:
"""Make the machine reachable.
"""
self.send_monitor_command("set_link virtio-net-pci.1 on")
def create_machine(args):
def create_machine(args: Dict[str, Any]) -> Machine:
global log
args["log"] = log
args["redirectSerial"] = os.environ.get("USE_SERIAL", "0") == "1"
return Machine(args)
def start_all():
def start_all() -> None:
global machines
with log.nested("starting all VMs"):
for machine in machines:
machine.start()
def join_all():
def join_all() -> None:
global machines
with log.nested("waiting for all VMs to finish"):
for machine in machines:
machine.wait_for_shutdown()
def test_script():
def test_script() -> None:
exec(os.environ["testScript"])
def run_tests():
def run_tests() -> None:
global machines
tests = os.environ.get("tests", None)
if tests is not None:
with log.nested("running the VM test script"):
@ -757,7 +767,7 @@ def run_tests():
@contextmanager
def subtest(name):
def subtest(name: str) -> Iterator[None]:
global nr_tests
global nr_succeeded
@ -774,7 +784,6 @@ def subtest(name):
if __name__ == "__main__":
global log
log = Logger()
vlan_nrs = list(dict.fromkeys(os.environ["VLANS"].split()))
@ -793,7 +802,7 @@ if __name__ == "__main__":
nr_succeeded = 0
@atexit.register
def clean_up():
def clean_up() -> None:
with log.nested("cleaning up"):
for machine in machines:
if machine.pid is None:

View File

@ -26,7 +26,7 @@ in rec {
nativeBuildInputs = [ makeWrapper ];
buildInputs = [ (python3.withPackages (p: [ p.ptpython ])) ];
checkInputs = with python3Packages; [ pylint black ];
checkInputs = with python3Packages; [ pylint black mypy ];
dontUnpack = true;
@ -34,6 +34,9 @@ in rec {
doCheck = true;
checkPhase = ''
mypy --disallow-untyped-defs \
--no-implicit-optional \
--ignore-missing-imports ${testDriverScript}
pylint --errors-only ${testDriverScript}
black --check --diff ${testDriverScript}
'';