nixos/test-driver: rm global logger

We remove the global rootlog in favor of instantiating the logger as
required in the __init__.py and pass it down as a parameter (of our
AbstractLogger type).
This commit is contained in:
Stefan Hertrampf 2024-05-07 15:12:38 +02:00
parent 303618c7e1
commit d07866cddc
6 changed files with 64 additions and 44 deletions

View File

@ -6,7 +6,12 @@ from pathlib import Path
import ptpython.repl
from test_driver.driver import Driver
from test_driver.logger import JunitXMLLogger, XMLLogger, rootlog
from test_driver.logger import (
CompositeLogger,
JunitXMLLogger,
TerminalLogger,
XMLLogger,
)
class EnvDefault(argparse.Action):
@ -108,21 +113,23 @@ def main() -> None:
args = arg_parser.parse_args()
output_directory = args.output_directory.resolve()
logger = CompositeLogger([TerminalLogger()])
if "LOGFILE" in os.environ.keys():
rootlog.add_logger(XMLLogger(os.environ["LOGFILE"]))
logger.add_logger(XMLLogger(os.environ["LOGFILE"]))
if args.junit_xml:
rootlog.add_logger(JunitXMLLogger(output_directory / args.junit_xml))
logger.add_logger(JunitXMLLogger(output_directory / args.junit_xml))
if not args.keep_vm_state:
rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")
logger.info("Machine state will be reset. To keep it, pass --keep-vm-state")
with Driver(
args.start_scripts,
args.vlans,
args.testscript.read_text(),
output_directory,
logger,
args.keep_vm_state,
args.global_timeout,
) as driver:
@ -138,7 +145,7 @@ def main() -> None:
tic = time.time()
driver.run_tests()
toc = time.time()
rootlog.info(f"test script finished in {(toc-tic):.2f}s")
logger.info(f"test script finished in {(toc-tic):.2f}s")
def generate_driver_symbols() -> None:
@ -147,7 +154,7 @@ def generate_driver_symbols() -> None:
in user's test scripts. That list is then used by pyflakes to lint those
scripts.
"""
d = Driver([], [], "", Path())
d = Driver([], [], "", Path(), CompositeLogger([]))
test_symbols = d.test_symbols()
with open("driver-symbols", "w") as fp:
fp.write(",".join(test_symbols.keys()))

View File

@ -9,7 +9,7 @@ from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional
from colorama import Fore, Style
from test_driver.logger import rootlog
from test_driver.logger import AbstractLogger
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.polling_condition import PollingCondition
from test_driver.vlan import VLan
@ -49,6 +49,7 @@ class Driver:
polling_conditions: List[PollingCondition]
global_timeout: int
race_timer: threading.Timer
logger: AbstractLogger
def __init__(
self,
@ -56,6 +57,7 @@ class Driver:
vlans: List[int],
tests: str,
out_dir: Path,
logger: AbstractLogger,
keep_vm_state: bool = False,
global_timeout: int = 24 * 60 * 60 * 7,
):
@ -63,12 +65,13 @@ class Driver:
self.out_dir = out_dir
self.global_timeout = global_timeout
self.race_timer = threading.Timer(global_timeout, self.terminate_test)
self.logger = logger
tmp_dir = get_tmp_dir()
with rootlog.nested("start all VLans"):
with self.logger.nested("start all VLans"):
vlans = list(set(vlans))
self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in vlans]
def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
for s in scripts:
@ -84,6 +87,7 @@ class Driver:
tmp_dir=tmp_dir,
callbacks=[self.check_polling_conditions],
out_dir=self.out_dir,
logger=self.logger,
)
for cmd in cmd(start_scripts)
]
@ -92,19 +96,19 @@ class Driver:
return self
def __exit__(self, *_: Any) -> None:
with rootlog.nested("cleanup"):
with self.logger.nested("cleanup"):
self.race_timer.cancel()
for machine in self.machines:
machine.release()
def subtest(self, name: str) -> Iterator[None]:
"""Group logs under a given test name"""
with rootlog.subtest(name):
with self.logger.subtest(name):
try:
yield
return True
except Exception as e:
rootlog.error(f'Test "{name}" failed with error: "{e}"')
self.logger.error(f'Test "{name}" failed with error: "{e}"')
raise e
def test_symbols(self) -> Dict[str, Any]:
@ -118,7 +122,7 @@ class Driver:
machines=self.machines,
vlans=self.vlans,
driver=self,
log=rootlog,
log=self.logger,
os=os,
create_machine=self.create_machine,
subtest=subtest,
@ -150,13 +154,13 @@ class Driver:
def test_script(self) -> None:
"""Run the test script"""
with rootlog.nested("run the VM test script"):
with self.logger.nested("run the VM test script"):
symbols = self.test_symbols() # call eagerly
exec(self.tests, symbols, None)
def run_tests(self) -> None:
"""Run the test script (for non-interactive test runs)"""
rootlog.info(
self.logger.info(
f"Test will time out and terminate in {self.global_timeout} seconds"
)
self.race_timer.start()
@ -168,13 +172,13 @@ class Driver:
def start_all(self) -> None:
"""Start all machines"""
with rootlog.nested("start all VMs"):
with self.logger.nested("start all VMs"):
for machine in self.machines:
machine.start()
def join_all(self) -> None:
"""Wait for all machines to shut down"""
with rootlog.nested("wait for all VMs to finish"):
with self.logger.nested("wait for all VMs to finish"):
for machine in self.machines:
machine.wait_for_shutdown()
self.race_timer.cancel()
@ -182,7 +186,7 @@ class Driver:
def terminate_test(self) -> None:
# This will be usually running in another thread than
# the thread actually executing the test script.
with rootlog.nested("timeout reached; test terminating..."):
with self.logger.nested("timeout reached; test terminating..."):
for machine in self.machines:
machine.release()
# As we cannot `sys.exit` from another thread
@ -227,7 +231,7 @@ class Driver:
f"Unsupported arguments passed to create_machine: {args}"
)
rootlog.warning(
self.logger.warning(
Fore.YELLOW
+ Style.BRIGHT
+ "WARNING: Using create_machine with a single dictionary argument is deprecated and will be removed in NixOS 24.11"
@ -246,13 +250,14 @@ class Driver:
start_command=cmd,
name=name,
keep_vm_state=keep_vm_state,
logger=self.logger,
)
def serial_stdout_on(self) -> None:
rootlog.print_serial_logs(True)
self.logger.print_serial_logs(True)
def serial_stdout_off(self) -> None:
rootlog.print_serial_logs(False)
self.logger.print_serial_logs(False)
def check_polling_conditions(self) -> None:
for condition in self.polling_conditions:
@ -271,6 +276,7 @@ class Driver:
def __init__(self, fun: Callable):
self.condition = PollingCondition(
fun,
driver.logger,
seconds_interval,
description,
)
@ -285,15 +291,17 @@ class Driver:
def wait(self, timeout: int = 900) -> None:
def condition(last: bool) -> bool:
if last:
rootlog.info(f"Last chance for {self.condition.description}")
driver.logger.info(
f"Last chance for {self.condition.description}"
)
ret = self.condition.check(force=True)
if not ret and not last:
rootlog.info(
driver.logger.info(
f"({self.condition.description} failure not fatal yet)"
)
return ret
with rootlog.nested(f"waiting for {self.condition.description}"):
with driver.logger.nested(f"waiting for {self.condition.description}"):
retry(condition, timeout=timeout)
if fun_ is None:

View File

@ -307,6 +307,3 @@ class XMLLogger(AbstractLogger):
self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")
self.xml.endElement("nest")
rootlog: CompositeLogger = CompositeLogger([TerminalLogger()])

View File

@ -17,7 +17,7 @@ from pathlib import Path
from queue import Queue
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from test_driver.logger import rootlog
from test_driver.logger import AbstractLogger
from .qmp import QMPSession
@ -270,6 +270,7 @@ class Machine:
out_dir: Path,
tmp_dir: Path,
start_command: StartCommand,
logger: AbstractLogger,
name: str = "machine",
keep_vm_state: bool = False,
callbacks: Optional[List[Callable]] = None,
@ -280,6 +281,7 @@ class Machine:
self.name = name
self.start_command = start_command
self.callbacks = callbacks if callbacks is not None else []
self.logger = logger
# set up directories
self.shared_dir = self.tmp_dir / "shared-xchg"
@ -307,15 +309,15 @@ class Machine:
return self.booted and self.connected
def log(self, msg: str) -> None:
rootlog.log(msg, {"machine": self.name})
self.logger.log(msg, {"machine": self.name})
def log_serial(self, msg: str) -> None:
rootlog.log_serial(msg, self.name)
self.logger.log_serial(msg, self.name)
def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
my_attrs = {"machine": self.name}
my_attrs.update(attrs)
return rootlog.nested(msg, my_attrs)
return self.logger.nested(msg, my_attrs)
def wait_for_monitor_prompt(self) -> str:
assert self.monitor is not None
@ -1113,8 +1115,8 @@ class Machine:
def cleanup_statedir(self) -> None:
shutil.rmtree(self.state_dir)
rootlog.log(f"deleting VM state directory {self.state_dir}")
rootlog.log("if you want to keep the VM state, pass --keep-vm-state")
self.logger.log(f"deleting VM state directory {self.state_dir}")
self.logger.log("if you want to keep the VM state, pass --keep-vm-state")
def shutdown(self) -> None:
"""
@ -1221,7 +1223,7 @@ class Machine:
def release(self) -> None:
if self.pid is None:
return
rootlog.info(f"kill machine (pid {self.pid})")
self.logger.info(f"kill machine (pid {self.pid})")
assert self.process
assert self.shell
assert self.monitor

View File

@ -2,7 +2,7 @@ import time
from math import isfinite
from typing import Callable, Optional
from .logger import rootlog
from test_driver.logger import AbstractLogger
class PollingConditionError(Exception):
@ -13,6 +13,7 @@ class PollingCondition:
condition: Callable[[], bool]
seconds_interval: float
description: Optional[str]
logger: AbstractLogger
last_called: float
entry_count: int
@ -20,11 +21,13 @@ class PollingCondition:
def __init__(
self,
condition: Callable[[], Optional[bool]],
logger: AbstractLogger,
seconds_interval: float = 2.0,
description: Optional[str] = None,
):
self.condition = condition # type: ignore
self.seconds_interval = seconds_interval
self.logger = logger
if description is None:
if condition.__doc__:
@ -41,7 +44,7 @@ class PollingCondition:
if (self.entered or not self.overdue) and not force:
return True
with self, rootlog.nested(self.nested_message):
with self, self.logger.nested(self.nested_message):
time_since_last = time.monotonic() - self.last_called
last_message = (
f"Time since last: {time_since_last:.2f}s"
@ -49,13 +52,13 @@ class PollingCondition:
else "(not called yet)"
)
rootlog.info(last_message)
self.logger.info(last_message)
try:
res = self.condition() # type: ignore
except Exception:
res = False
res = res is None or res
rootlog.info(self.status_message(res))
self.logger.info(self.status_message(res))
return res
def maybe_raise(self) -> None:

View File

@ -4,7 +4,7 @@ import pty
import subprocess
from pathlib import Path
from test_driver.logger import rootlog
from test_driver.logger import AbstractLogger
class VLan:
@ -19,17 +19,20 @@ class VLan:
pid: int
fd: io.TextIOBase
logger: AbstractLogger
def __repr__(self) -> str:
return f"<Vlan Nr. {self.nr}>"
def __init__(self, nr: int, tmp_dir: Path):
def __init__(self, nr: int, tmp_dir: Path, logger: AbstractLogger):
self.nr = nr
self.socket_dir = tmp_dir / f"vde{self.nr}.ctl"
self.logger = logger
# TODO: don't side-effect environment here
os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir)
rootlog.info("start vlan")
self.logger.info("start vlan")
pty_master, pty_slave = pty.openpty()
# The --hub is required for the scenario determined by
@ -52,11 +55,11 @@ class VLan:
assert self.process.stdout is not None
self.process.stdout.readline()
if not (self.socket_dir / "ctl").exists():
rootlog.error("cannot start vde_switch")
self.logger.error("cannot start vde_switch")
rootlog.info(f"running vlan (pid {self.pid}; ctl {self.socket_dir})")
self.logger.info(f"running vlan (pid {self.pid}; ctl {self.socket_dir})")
def __del__(self) -> None:
rootlog.info(f"kill vlan (pid {self.pid})")
self.logger.info(f"kill vlan (pid {self.pid})")
self.fd.close()
self.process.terminate()