
211 lines
6.9 KiB
Executable File

#!/usr/bin/env nix-shell
#!nix-shell -i python3 -p "python3.withPackages (ps: [ ps.psutil ])"
# vim: set filetype=python :
"""
USAGE: sane-die-with-parent [options...] <cmd> [args ...]

run `cmd` such that when the caller of sane-die-with-parent exits, `cmd` exits as well.

options:
--catch-sigkill: if this process is SIGKILL'd, also forward that to descendants/pgroup.
  (SIGKILL itself can't be caught, so this re-runs sane-die-with-parent as a supervised child.)
--descendants: run as a supervisor, and kill every process spawned below this one when the parent dies.
  this is useful for running programs which also don't propagate the death signal,
  such as `bubblewrap`.
  without this, the default is to `exec` <cmd>.
--signal SIGKILL|SIGTERM: control the signal which is sent to child processes (default: SIGTERM).
--use-pgroup: kill children by killing the entire process group instead of walking down the hierarchy.
--verbose: enable debug logging.
"""
import ctypes
import logging
import os
import psutil
import signal
import sys
# prctl options (values from <linux/prctl.h>):
PR_SET_PDEATHSIG = 1
PR_SET_CHILD_SUBREAPER = 36

# handle to the C functions visible in this process; used to call `prctl`.
libc = ctypes.CDLL("")

logger = logging.getLogger(__name__)

# when running as a supervisor, try to exit with the same code as the direct child i spawned.
# that might not always be possible, in which case exit error if *any* child errored, and the direct child is unknown.
# updated by `child_exited`; None means "no exit code recorded yet".
EXIT_CODE = None
def child_exited(exitcode: int, direct: bool = False) -> None:
    """
    call when any child exits, to track the exit code this process should return.

    - `exitcode`: the child's exit code. a negative value -N means "killed by signal N"
      (the convention of `os.waitstatus_to_exitcode`); it is recorded as 128+N, the
      shell convention, because passing -N to `sys.exit` would surface as status 256-N.
    - `direct`: True for the child this process spawned itself. its code always wins;
      any other child only sets the code if none is recorded yet and it failed.
    """
    global EXIT_CODE
    logger.debug(f"child exited with {exitcode} (direct child? {direct})")
    if exitcode is not None and exitcode < 0:
        exitcode = 128 - exitcode
    if direct:
        EXIT_CODE = exitcode
    elif EXIT_CODE is None and exitcode:
        EXIT_CODE = exitcode
def assert_prctl(*args):
    """
    call `prctl(2)` with `args` and fail loudly unless it returns 0.

    raises OSError explicitly instead of using `assert`: asserts are stripped under
    `python -O`, which would let a failed prctl go unnoticed and leave this process
    without the protection it asked for.
    """
    rc = libc.prctl(*args)
    if rc != 0:
        raise OSError(f"prctl({args}) returned unexpected {rc}")
def set_pdeathsig(sig: signal.Signals=signal.SIGTERM):
    """
    have the kernel deliver `sig` to this process once its parent exits.

    this is `prctl(PR_SET_PDEATHSIG, sig)`. strictly, the signal fires when the parent
    *thread* that created this process terminates; see the prctl(2) man page for that
    and for what clears the setting (fork, set-user-ID execs, ...).
    """
    assert_prctl(PR_SET_PDEATHSIG, sig)
def become_reaper():
    """
    make this process a "child subreaper" (`prctl(PR_SET_CHILD_SUBREAPER, 1)`).

    descendants that become orphaned are then re-parented to this process, so
    `wait_all_children` can cover the whole tree simply by waiting on direct
    children until there are none left. see the prctl(2) man page.
    """
    enabled = 1
    assert_prctl(PR_SET_CHILD_SUBREAPER, enabled, 0, 0, 0)
def kill_all_on_exit(sig: signal.Signals=signal.SIGTERM):
    """
    catch when the parent exits, and map that to SIGTERM for this process.
    when this process receives SIGTERM (or SIGINT), send `sig` to all descendant processes.
    """
    def on_terminate(_sig=None, _frame=None):
        kill_all_children(sig)

    # install the handlers *before* arming the parent-death signal: if the parent exits
    # in between, the SIGTERM must run the cleanup rather than the default (fatal) action.
    signal.signal(signal.SIGTERM, on_terminate)
    signal.signal(signal.SIGINT, on_terminate)
    set_pdeathsig(signal.SIGTERM)
# process group signalled by `kill_process_group`. None means process groups are not
# in use; `main` sets it (to this process's pid) when --use-pgroup is in effect.
PGID = None
def do_spawn(cli: list[str], killsig: signal.Signals) -> None:
    """
    run the command as a child, and:
    - if the parent exits, then kill all descendents (with `killsig`).
    - either way, wait for all descendents to exit before returning.

    exit codes are recorded through `child_exited`, for the caller to read back via `EXIT_CODE`.
    """
    # adopt orphaned descendants, and turn parent death / SIGTERM / SIGINT into a sweep
    # of the tree. done before forking so no descendant is created unsupervised.
    become_reaper()
    kill_all_on_exit(killsig)

    child_pid = os.fork()
    if child_pid == 0:
        logger.debug("exec as child: %s", cli)
        try:
            os.execvp(cli[0], cli)
        except OSError as e:
            # the forked child must never continue into the supervisor code below.
            logger.error("failed to exec %s: %s", cli[0], e)
            os._exit(127)

    # this process is the parent
    try:
        _pid, status = os.waitpid(child_pid, 0)
    except ChildProcessError as e:
        # presumably the child was already reaped elsewhere (e.g. by `kill_all_children`
        # running from a signal handler while we were blocked here).
        logger.debug(f"failed to get exit code of direct child: {e}")
    else:
        # `status` is a raw wait status, not an exit code: e.g. exit code 1 is status 256,
        # and `sys.exit(256)` would report success. decode it first.
        child_exited(os.waitstatus_to_exitcode(status), direct=True)

    wait_all_children()
def wait_all_children() -> None:
    """
    block until this process has no children left, reaping each one as it exits.

    with `become_reaper` in effect, orphaned descendants are re-parented here, so this
    waits for the whole tree. every exit is passed to `child_exited` (as a non-direct
    child), so a failing orphan still counts when the direct child's status is unknown.
    """
    while True:
        logger.debug("waiting for child...")
        try:
            pid, status = os.wait()
        except ChildProcessError:
            # ChildProcessError is exactly the ECHILD case: nothing left to wait for.
            logger.debug("no more children")
            return
        logger.debug(f"child {pid} exited {status}")
        child_exited(os.waitstatus_to_exitcode(status))
def kill_all_children(sig: signal.Signals=signal.SIGTERM) -> None:
    """
    send `sig` to every child of this process, repeating until none remain.

    first signals the process group (a no-op unless process groups are in use). then,
    for a bounded number of rounds, signals each direct child and reaps those that exit,
    recording their exit codes via `child_exited`. with `become_reaper` in effect,
    orphans re-parented here show up as new children in later rounds.
    """
    # SIGKILL can't be ignored, and this process belongs to its own process group, so
    # broadcasting SIGKILL to the group would kill this supervisor as well.
    if sig != signal.SIGKILL:
        kill_process_group(sig)
    for _ in range(20):  # max attempts, arbitrary
        children = psutil.Process().children()
        logger.debug(f"kill_all_children: {children}")
        if not children:
            break
        for child in children:
            try:
                child.send_signal(sig)
            except psutil.NoSuchProcess:
                pass  # it exited since we listed it
        gone, _alive = psutil.wait_procs(children, timeout=0.5)
        for p in gone:
            child_exited(p.returncode)
def kill_process_group(sig: signal.Signals=signal.SIGTERM) -> None:
    """
    send `sig` to the process group `PGID`; do nothing when no group is in use (PGID is None).

    this process is itself a member of that group, so (for any `sig` other than SIGKILL,
    which can't be ignored) it first ignores `sig` locally to avoid re-entering its own
    termination handler when the broadcast reaches it.
    """
    if PGID is None:
        return
    logger.debug("killing process group %d with %d", PGID, sig)
    if sig != signal.SIGKILL:
        # NOTE: not restored afterwards, so later deliveries of `sig` are ignored by this process.
        signal.signal(sig, signal.Handlers.SIG_IGN)
    os.killpg(PGID, sig)
def do_exec(cli: list[str], killsig: signal.Signals) -> None:
    """
    execute the command inline, but configured so that when my parent exits, this command will exit.

    arms `killsig` as this process's parent-death signal and then replaces this process
    with `cli`. the setting is kept across `execve`, except in the cases listed in
    prctl(2) (e.g. executing a set-user-ID binary clears it).
    """
    set_pdeathsig(killsig)
    logger.debug("exec: %s", cli)
    os.execvp(cli[0], cli)
def main():
    """
    parse the command line (see the usage text) and run <cmd>.

    <cmd> is either `exec`'d in place (the default) or spawned under a supervisor
    (`--descendants`; `--catch-sigkill` also makes the outer process a supervisor).
    when supervising, exits with the exit code recorded in `EXIT_CODE`.
    """
    global PGID

    args = sys.argv[1:]
    catch_sigkill = False
    descendants = False
    # raw --signal text (forwarded verbatim when nesting) and the parsed signal
    killsig_, killsig = None, signal.SIGTERM
    use_pgroup = False
    verbose = False
    while args and args[0].startswith("--"):
        flag, args = args[0], args[1:]
        if flag == "--catch-sigkill":
            catch_sigkill = True
        elif flag in ("--descendants", "--descendents"):
            # "--descendents" is accepted as an alternate spelling, for compatibility
            descendants = True
        elif flag == "--use-pgroup":
            use_pgroup = True
        elif flag == "--verbose":
            verbose = True
        elif flag == "--signal":
            if not args:
                sys.exit("--signal requires a value, e.g. SIGTERM")
            killsig_, args = args[0], args[1:]
            try:
                # only real signal names: getattr(signal, ...) would also accept
                # unrelated module attributes such as ITIMER_REAL
                killsig = signal.Signals[killsig_.upper()]
            except KeyError:
                sys.exit(f"unrecognized signal {killsig_!r}")
        else:
            # explicit exit rather than `assert`, which `python -O` strips
            # (silently ignoring the bad flag)
            sys.exit(f"unrecognized argument {flag!r}")

    cli = args
    if not cli:
        sys.exit("USAGE: sane-die-with-parent [options...] <cmd> [args ...]")

    if verbose:
        logging.basicConfig(level=logging.DEBUG)

    if catch_sigkill:
        # SIGKILL can't be caught, so re-run this tool (same options) as a supervised child.
        # that inner copy is expected to observe this process's death (via its
        # parent-death signal) and clean up on our behalf.
        nested_args = [sys.argv[0]]
        if descendants:
            nested_args += ["--descendants"]
        descendants = True  # it's less that we need the outer process to kill its descendants, so much as that it must *exist*
        if killsig_:
            nested_args += ["--signal", killsig_]
        if use_pgroup:
            nested_args += ["--use-pgroup"]
        use_pgroup = False  # doesn't make sense for parent to use pgroups
        if verbose:
            nested_args += ["--verbose"]
        cli = nested_args + cli

    if use_pgroup:
        PGID = os.getpid()
        # create a new process group led by this process (pgid = pid)
        os.setpgid(PGID, PGID)

    if descendants:
        do_spawn(cli, killsig)
        sys.exit(EXIT_CODE or 0)
    else:
        do_exec(cli, killsig)
if __name__ == "__main__":
    main()