211 lines
6.9 KiB
Python
Executable File
211 lines
6.9 KiB
Python
Executable File
#!/usr/bin/env nix-shell
|
|
#!nix-shell -i python3 -p "python3.withPackages (ps: [ ps.psutil ])"
|
|
# vim: set filetype=python :
|
|
"""
|
|
USAGE: sane-die-with-parent [options...] <cmd> [args ...]
|
|
|
|
run `cmd` such that when the caller of sane-die-with-parent exits, `cmd` exits as well.
|
|
|
|
OPTIONS:
|
|
--catch-sigkill: if this process is SIGKILL'd, also forward that to descendants/pgroup.
|
|
--descendants: run as a supervisor, and kill every process spawned below this one when the parent dies.
|
|
this is useful for running programs which also don't propagate the death signal,
|
|
such as `bubblewrap`.
|
|
without this, the default is to `exec` <cmd>.
|
|
--signal SIGKILL|SIGTERM: control the signal which is sent to child processes.
|
|
--use-pgroup: kill children by killing the entire process group instead of walking down the hierarchy.
|
|
--verbose
|
|
"""
|
|
|
|
import ctypes
|
|
import logging
|
|
import os
|
|
import psutil
|
|
import signal
|
|
import sys
|
|
|
|
# option numbers understood by prctl(2):
PR_SET_PDEATHSIG = 1  # used by `set_pdeathsig`
PR_SET_CHILD_SUBREAPER = 36  # used by `become_reaper`

# handle to the C library; only `prctl` is called through it
libc = ctypes.CDLL("libc.so.6")
logger = logging.getLogger(__name__)
|
|
|
|
# when running as a supervisor, try to exit with the same code as the direct child that was spawned.
# that isn't always possible: when the direct child's exit code is unknown,
# exit with an error if *any* child exited with an error.
EXIT_CODE = None
|
|
def child_exited(exitcode: int, direct: bool = False) -> None:
    """
    record that a child exited, updating the exit code this process should return.

    the direct child's exit code always overrides whatever was recorded before.
    for any other child, a non-zero exit code is recorded only if nothing has been recorded yet.
    """
    global EXIT_CODE
    logger.debug(f"child exited with {exitcode} (direct child? {direct})")

    nothing_recorded_yet = EXIT_CODE is None
    if direct or (nothing_recorded_yet and exitcode):
        EXIT_CODE = exitcode
|
|
|
|
|
|
def assert_prctl(*args):
    """
    call `prctl` from libc with the given arguments, and fail loudly if it doesn't return 0.

    raises AssertionError when the return code is non-zero.
    """
    rc = libc.prctl(*args)
    if rc != 0:
        # raised explicitly (instead of via `assert`) so the check isn't stripped by `python -O`
        raise AssertionError(f"prctl({args}) returned unexpected {rc}")
|
|
|
|
def set_pdeathsig(sig: signal.Signals = signal.SIGTERM):
    """
    request (via prctl PR_SET_PDEATHSIG) that `sig` be sent to this process once its parent process exits.
    see: <https://stackoverflow.com/a/43152455>
    see: <https://www.man7.org/linux/man-pages/man2/prctl.2.html>
    """
    assert_prctl(PR_SET_PDEATHSIG, sig)
|
|
|
|
def become_reaper():
    """
    mark this process as a "child subreaper": descendent processes which become orphaned
    are reparented to this process.
    that makes it possible (as in `wait_all_children`) to wait for the entire process tree
    just by waiting on direct children until no direct children remain.
    see: <https://jmmv.dev/2019/11/wait-for-process-group-linux.html>
    """
    assert_prctl(PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0)
|
|
|
|
|
|
def kill_all_on_exit(sig: signal.Signals = signal.SIGTERM):
    """
    arrange for this process to receive SIGTERM when its parent exits,
    and for SIGTERM or SIGINT received by this process to trigger `kill_all_children(sig)`.
    """
    def forward_to_children(_sig=None, _frame=None):
        kill_all_children(sig)

    set_pdeathsig()
    for trigger in (signal.SIGTERM, signal.SIGINT):
        signal.signal(trigger, forward_to_children)
|
|
|
|
# id of the process group that `kill_process_group` signals; None means no process group is in use.
PGID = None
|
|
|
|
def do_spawn(cli: list[str], killsig: signal.Signals) -> None:
    """
    run the command as a child, and:
    - if the parent exits (or this process gets SIGTERM/SIGINT), send `killsig` to all children.
    - either way, wait for the direct child to exit before returning,
      and record its exit code via `child_exited(..., direct=True)`.
    """
    child_pid = os.fork()
    if child_pid == 0:
        logger.debug("exec as child: %s", cli)
        try:
            os.execvp(cli[0], cli)
        except OSError as e:
            # exec failed, so this is still a forked copy of the supervisor:
            # exit right away instead of unwinding through the supervisor's code.
            logger.error("failed to exec %s: %s", cli, e)
            os._exit(127)

    # this process is the parent
    kill_all_on_exit(killsig)
    try:
        _pid, status = os.waitpid(child_pid, 0)
    except ChildProcessError as e:
        # the child may already have been reaped elsewhere (e.g. by the signal handler)
        logger.debug(f"failed to get exit code of direct child: {e}")
    else:
        # `waitpid` yields an encoded wait status, not an exit code: e.g. `exit(1)` is status 256,
        # which would wrap around to 0 if handed to `sys.exit` as-is.
        # after decoding, death-by-signal is represented as a negative signal number.
        child_exited(os.waitstatus_to_exitcode(status), direct=True)
|
|
|
|
def wait_all_children() -> None:
    """
    block until every child of this process has exited,
    recording each child's exit code via `child_exited`.
    """
    while True:
        logger.debug("waiting for child...")
        try:
            pid, status = os.wait()
        except ChildProcessError:
            # ChildProcessError corresponds to ECHILD: there is nothing left to wait for.
            logger.debug("no more children")
            break

        logger.debug(f"child {pid} exited {status}")
        # `os.wait` yields an encoded wait status; decode it to a real exit code
        # (negative signal number if the child was killed by a signal).
        child_exited(os.waitstatus_to_exitcode(status))
|
|
|
|
def kill_all_children(sig: signal.Signals = signal.SIGTERM) -> None:
    """
    send `sig` to every direct child of this process, repeating until no children remain
    (or the attempt limit is hit), and record the exit code of each child that exits.
    the process group (if one is in use) is signalled as well.
    """
    # for SIGKILL, signalling the process group is deferred to the very end
    # (presumably because this process is itself a member of that group, and SIGKILL can't be ignored).
    if sig != signal.SIGKILL:
        kill_process_group(sig)

    for _ in range(20):  # max attempts, arbitrary
        children = psutil.Process().children()
        logger.debug(f"kill_all_children: {children}")
        if not children:
            break

        for child in children:
            try:
                child.send_signal(sig)
            except psutil.NoSuchProcess:
                # the child exited between being listed and being signalled; nothing to do.
                pass
        gone, _alive = psutil.wait_procs(children, timeout=0.5)
        for p in gone:
            # `wait_procs` attaches the exit status to each gone process as `returncode`
            child_exited(p.returncode)

    kill_process_group(sig)
|
|
|
|
def kill_process_group(sig: signal.Signals = signal.SIGTERM) -> None:
    """
    send `sig` to the process group recorded in `PGID`.
    does nothing when `PGID` is unset.
    """
    global PGID
    if PGID is not None:
        logger.debug("killing process group %d with %d", PGID, sig)
        if sig != signal.SIGKILL:
            # ignore the signal in this process first, to avoid recursing
            signal.signal(sig, signal.Handlers.SIG_IGN)
        os.killpg(PGID, sig)
        # NOTE: the signal stays ignored afterwards; restoring the default handler is disabled:
        # if sig != signal.SIGKILL: signal.signal(sig, signal.Handlers.SIG_DFL)
|
|
|
|
def do_exec(cli: list[str], killsig: signal.Signals) -> None:
    """
    replace this process with the command, first arranging for `killsig`
    to be delivered when my parent exits.
    """
    set_pdeathsig(killsig)
    logger.debug("exec: %s", cli)
    program = cli[0]
    os.execvp(program, cli)
|
|
|
|
def main():
    """
    parse the command line, then either:
    - exec the command directly (the default), or
    - spawn it as a child and supervise it (`--descendants`, or implied by `--catch-sigkill`),
      exiting with the tracked exit code.

    invalid usage (unknown flag, bad/missing `--signal` value, missing command)
    exits with an error message instead of a traceback.
    """
    global PGID
    logging.basicConfig()

    def usage_error(message: str):
        # `sys.exit(str)` prints the message to stderr and exits with status 1
        sys.exit(f"sane-die-with-parent: {message}")

    args = sys.argv[1:]
    catch_sigkill = False
    descendants = False
    killsig_, killsig = None, signal.SIGTERM
    use_pgroup = False
    verbose = False
    while args and args[0].startswith("--"):
        flag, args = args[0], args[1:]
        if flag == "--catch-sigkill":
            catch_sigkill = True
        elif flag in ("--descendants", "--descendents"):
            # accept the alternate spelling `--descendents` too
            descendants = True
        elif flag == "--use-pgroup":
            use_pgroup = True
        elif flag == "--verbose":
            verbose = True
        elif flag == "--signal":
            if not args:
                usage_error("--signal requires an argument")
            killsig_, args = args[0], args[1:]
            killsig = getattr(signal, killsig_.upper(), None)
            # guard against names which exist in the `signal` module but aren't signals
            if not isinstance(killsig, signal.Signals):
                usage_error(f"unrecognized signal {killsig_!r}")
        else:
            usage_error(f"unrecognized argument {flag!r}")
    cli = args
    if not cli:
        usage_error("missing <cmd>")

    if verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    if catch_sigkill:
        # re-invoke this same program as the supervised command, forwarding the relevant flags,
        # so that an outer supervisor process outlives the inner one.
        nested_args = [sys.argv[0]]
        if descendants:
            nested_args += ["--descendants"]
        descendants = True  # it's less that we need the outer process to kill its descendants, so much as that it must *exist*
        if killsig_:
            nested_args += ["--signal", killsig_]
        if use_pgroup:
            nested_args += ["--use-pgroup"]
        use_pgroup = False  # doesn't make sense for parent to use pgroups
        if verbose:
            nested_args += ["--verbose"]

        cli = nested_args + cli

    if use_pgroup:
        PGID = os.getpid()
        # create a new process group, pgid = pid
        os.setpgid(PGID, PGID)

    if descendants:
        become_reaper()
        do_spawn(cli, killsig)
        # the direct child is done (or unwaitable): clean up whatever else is still running
        kill_all_children()
        sys.exit(EXIT_CODE or 0)
    else:
        do_exec(cli, killsig)
|
|
|
|
# script entry point: only run `main` when executed directly, not when imported.
if __name__ == "__main__":
    main()
|