nix-files/pkgs/additional/sane-die-with-parent/sane-die-with-parent

211 lines
6.9 KiB
Python
Executable File

#!/usr/bin/env nix-shell
#!nix-shell -i python3 -p "python3.withPackages (ps: [ ps.psutil ])"
# vim: set filetype=python :
"""
USAGE: sane-die-with-parent [options...] <cmd> [args ...]
run `cmd` such that when the caller of sane-die-with-parent exits, `cmd` exits as well.
OPTIONS:
--catch-sigkill: if this process is SIGKILL'd, also forward that to descendants/pgroup.
--descendants: run as a supervisor, and kill every process spawned below this one when the parent dies.
this is useful for running programs which also don't propagate the death signal,
such as `bubblewrap`.
without this, the default is to `exec` <cmd>.
--signal SIGKILL|SIGTERM: control the signal which is sent to child processes.
--use-pgroup: kill children by killing the entire process group instead of walking down the hierarchy.
--verbose
"""
import ctypes
import logging
import os
import psutil
import signal
import sys
# prctl options (numeric values of the like-named constants in <linux/prctl.h>):
PR_SET_PDEATHSIG = 1  # used by `set_pdeathsig`: signal this process when its parent exits
PR_SET_CHILD_SUBREAPER = 36  # used by `become_reaper`: reparent orphaned descendants to this process
# handle to the C library; `prctl` is called through it directly (see `assert_prctl`)
libc = ctypes.CDLL("libc.so.6")
logger = logging.getLogger(__name__)
# when running as a supervisor, try to exit with the same code as the direct child i spawned.
# that might not always be possible, in which case exit error if *any* child errored, and the direct child is unknown.
# stays `None` until `child_exited` records something; `main` treats `None` as 0.
EXIT_CODE = None
def child_exited(exitcode: int, direct: bool = False) -> None:
    """
    record that a child process exited, updating the code this process should itself exit with.

    - the direct child's exit code always replaces whatever was recorded before.
    - any other child only contributes a non-zero code, and only while nothing has been recorded yet.
    """
    global EXIT_CODE
    logger.debug(f"child exited with {exitcode} (direct child? {direct})")
    nothing_recorded_yet = EXIT_CODE is None
    if direct or (nothing_recorded_yet and exitcode):
        EXIT_CODE = exitcode
def assert_prctl(*args):
    """
    call `prctl(2)` through libc with the given arguments, and fail loudly if it does not succeed.

    raises AssertionError if prctl returns anything other than 0.
    """
    rc = libc.prctl(*args)
    # raise explicitly rather than via `assert`: assert statements are stripped under `python -O`,
    # which would let a failed prctl (and so a missing death signal / subreaper) go unnoticed.
    if rc != 0:
        raise AssertionError(f"prctl({args}) returned unexpected {rc}")
def set_pdeathsig(sig: signal.Signals=signal.SIGTERM):
    """
    arrange for `sig` to be delivered to this process once its parent process exits,
    so that this process dies along with its parent.
    see: <https://www.man7.org/linux/man-pages/man2/prctl.2.html>
    see: <https://stackoverflow.com/a/43152455>
    """
    prctl_args = (PR_SET_PDEATHSIG, sig)
    assert_prctl(*prctl_args)
def become_reaper():
    """
    mark this process as a child subreaper: any descendent process which becomes
    orphaned is reparented to this process.
    that makes it possible (in `wait_all_children`) to wait for the entire process group
    just by waiting on direct children until none are left.
    see: <https://jmmv.dev/2019/11/wait-for-process-group-linux.html>
    """
    enable = 1
    assert_prctl(PR_SET_CHILD_SUBREAPER, enable, 0, 0, 0)
def kill_all_on_exit(sig: signal.Signals=signal.SIGTERM):
    """
    have this process receive SIGTERM when its parent exits, and respond to
    SIGTERM (or SIGINT) by sending `sig` to all descendent processes.
    """
    def forward_to_children(_sig=None, _frame=None):
        # signal handler: the received signal number and frame are irrelevant,
        # the children always get the configured `sig`.
        return kill_all_children(sig)

    set_pdeathsig()
    for caught in (signal.SIGTERM, signal.SIGINT):
        signal.signal(caught, forward_to_children)
# id of the process group which `kill_process_group` signals.
# `None` (the default) disables process-group killing; `main` sets it when `--use-pgroup` is in effect.
PGID = None
def do_spawn(cli: list[str], killsig: signal.Signals) -> None:
    """
    run the command as a child, and:
    - if the parent exits (or this process gets SIGTERM/SIGINT), send `killsig` to all descendents.
    - either way, wait for the direct child to exit before returning, and record its exit code
      via `child_exited(..., direct=True)`.

    cli: the program to run, followed by its arguments.
    killsig: the signal forwarded to descendents (see `kill_all_on_exit`).
    """
    child_pid = os.fork()
    if child_pid == 0:
        logger.debug("exec as child: %s", cli)
        os.execvp(cli[0], cli)
    else:
        # this process is the parent
        kill_all_on_exit(killsig)
        try:
            _pid, status = os.waitpid(child_pid, 0)
        except ChildProcessError as e:
            # the child is no longer waitable (presumably already reaped elsewhere, e.g. by the signal handler)
            logger.debug(f"failed to get exit code of direct child: {e}")
        else:
            # `os.waitpid` yields an encoded wait *status*, not an exit code: a child which
            # exits with code 1 has status 256, and `sys.exit(256)` would be truncated to 0.
            exitcode = os.waitstatus_to_exitcode(status)
            if exitcode < 0:
                # child was killed by signal `-exitcode`: report that the way a shell does (128 + signum)
                exitcode = 128 - exitcode
            child_exited(exitcode, direct=True)
def wait_all_children() -> None:
    """
    reap children of this process, one at a time, until there are none left.
    each reaped child's exit code is recorded via `child_exited`.
    """
    while True:
        logger.debug("waiting for child...")
        try:
            pid, status = os.wait()
        except ChildProcessError:
            # ChildProcessError is python's mapping of ECHILD: there is nothing left to wait for.
            logger.debug("no more children")
            break
        # `os.wait` yields an encoded wait *status*; decode it so that `child_exited`
        # receives a real exit code (status 256 means "exited with code 1").
        exitcode = os.waitstatus_to_exitcode(status)
        if exitcode < 0:
            # child was killed by signal `-exitcode`: report that the way a shell does (128 + signum)
            exitcode = 128 - exitcode
        logger.debug(f"child {pid} exited {exitcode}")
        child_exited(exitcode)
def kill_all_children(sig: signal.Signals=signal.SIGTERM) -> None:
    """
    send `sig` to every child of this process, repeating until no children remain
    (or an arbitrary number of attempts is exhausted), and record the exit code of each
    child that goes away via `child_exited`.

    the process group (if one is configured, see `kill_process_group`) is signalled as well:
    up front for any signal other than SIGKILL, and once more after the children were handled.
    SIGKILL is only sent to the group at the very end, presumably because it cannot be ignored
    and the group set up by `main` contains this process itself.
    """
    if sig != signal.SIGKILL: kill_process_group(sig)
    for _ in range(20):  # max attempts, arbitrary
        children = psutil.Process().children()
        logger.debug(f"kill_all_children: {children}")
        if not children:
            break
        for child in children:
            try:
                child.send_signal(sig)
            except psutil.NoSuchProcess:
                # the child exited between being listed and being signalled; nothing left to kill.
                logger.debug(f"child {child.pid} already gone")
        gone, _alive = psutil.wait_procs(children, timeout=0.5)
        for p in gone:
            # `wait_procs` stores the exit status on each gone process as `returncode` (may be None)
            child_exited(p.returncode)
    kill_process_group(sig)
def kill_process_group(sig: signal.Signals=signal.SIGTERM) -> None:
    """
    send `sig` to the process group recorded in `PGID`.
    does nothing when `PGID` is unset (i.e. no process group was configured).
    """
    if PGID is None:
        return

    logger.debug("killing process group %d with %d", PGID, sig)
    if sig != signal.SIGKILL:
        # ignore `sig` in this process before signalling the group, to avoid recursing
        # into our own handler. the handler is not restored afterwards.
        signal.signal(sig, signal.SIG_IGN)
    os.killpg(PGID, sig)
def do_exec(cli: list[str], killsig: signal.Signals) -> None:
    """
    replace this process with the command, after configuring it so that `killsig`
    is delivered (to the command) when my parent exits.
    """
    set_pdeathsig(killsig)
    logger.debug("exec: %s", cli)
    program, argv = cli[0], cli
    os.execvp(program, argv)
def main():
    """
    parse the command line (flags are described in the module docstring) and run <cmd>:
    - with `--descendants` (or `--catch-sigkill`): fork <cmd>, supervise it, kill whatever
      is left once it exits, then exit with the recorded exit code (0 if none was recorded).
    - otherwise: `exec` <cmd> directly, with the parent-death signal configured.

    malformed invocations (unknown flag, `--signal` without a valid value, no <cmd>)
    terminate with a message on stderr and exit status 1.
    """
    global PGID

    logging.basicConfig()

    args = sys.argv[1:]
    catch_sigkill = False
    descendants = False
    # killsig_ is the signal name exactly as the user typed it (needed to re-invoke ourselves below)
    killsig_, killsig = None, signal.SIGTERM
    use_pgroup = False
    verbose = False
    while args and args[0].startswith("--"):
        flag, args = args[0], args[1:]
        if flag == "--catch-sigkill":
            catch_sigkill = True
        elif flag == "--descendants":
            descendants = True
        elif flag == "--use-pgroup":
            use_pgroup = True
        elif flag == "--verbose":
            verbose = True
        elif flag == "--signal":
            if not args:
                sys.exit("--signal requires a value, e.g. SIGKILL or SIGTERM")
            killsig_, args = args[0], args[1:]
            try:
                killsig = signal.Signals[killsig_.upper()]
            except KeyError:
                sys.exit(f"unrecognized signal {killsig_!r}")
        else:
            # explicit exit rather than `assert False`: asserts vanish under `python -O`
            sys.exit(f"unrecognized argument {flag!r}")

    cli = args
    if not cli:
        sys.exit("USAGE: sane-die-with-parent [options...] <cmd> [args ...]")

    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    if catch_sigkill:
        # run this same script again, with the user's flags, as the command to supervise.
        nested_args = [ sys.argv[0] ]
        if descendants:
            nested_args += [ "--descendants" ]
        descendants = True  # it's less that we need the outer process to kill its descendants, so much as that it must *exist*
        if killsig_:
            nested_args += [ "--signal", killsig_ ]
        if use_pgroup:
            nested_args += [ "--use-pgroup" ]
        use_pgroup = False  # doesn't make sense for parent to use pgroups
        if verbose:
            nested_args += [ "--verbose" ]
        cli = nested_args + cli

    if use_pgroup:
        PGID = os.getpid()
        # create a new process group, whose id is this process's pid
        os.setpgid(PGID, PGID)

    if descendants:
        become_reaper()
        do_spawn(cli, killsig)
        # NOTE(review): this cleanup uses the default SIGTERM rather than `killsig` -- confirm that's intended.
        kill_all_children()
        sys.exit(EXIT_CODE or 0)
    else:
        do_exec(cli, killsig)


if __name__ == "__main__":
    main()