nix-files/modules/services/trust-dns/trust-dns-nmhook

169 lines
6.1 KiB
Python
Executable File

#!/usr/bin/env nix-shell
#!nix-shell -i python3 -p "python3.withPackages (ps: [ ])" -p systemd
# vim: set filetype=python :
# /etc/NetworkManager/dispatcher.d/trust-dns-nmhook:
# NetworkManager-dispatcher.service calls this script whenever any network changes state.
# this includes when we activate a new network and receive DHCP info.
# specifically, this script propagates DHCP info to my DNS setup,
# ensuring things like "search domains" work (sorta) with my recursive resolver.
#
# NetworkManager-dispatcher invokes this with env vars related to the action/device/connection. notably:
# - DEVICE_IFACE (e.g. "wlp3s0")
# - DHCP4_DOMAIN_NAME_SERVERS (e.g. "1.1.1.1 4.4.4.4")
# - DHCP4_DOMAIN_SEARCH (e.g. "home.lan uninsane.org")
# - IP4_NAMESERVERS (e.g. "1.1.1.1")
# - CONNECTION_ID (e.g. "my-ssid-name")
# - CONNECTION_FILENAME (e.g. "/etc/NetworkManager/system-connections/XfinityWifi.nmconnection")
import argparse
import logging
import os
import subprocess
# module-level logger; `main` configures the root logger's handler and level.
logger: logging.Logger = logging.getLogger(__name__)
# default base directory for all files this hook writes (overridable with --dns-dir).
DNS_DIR: str = "/var/lib/trust-dns"
# subdirectory (relative to the base directory) holding one `<conn_id>.toml` per connection.
DHCP_CONFIGS: str = "dhcp-configs"
class Ops:
    """
    The side-effecting operations this hook performs, rooted at `base_dir`.

    Relative paths given to `write_file` / `makedirs` are joined onto `base_dir`.
    `exec_` runs an argv list directly (no shell).
    """

    def __init__(self, base_dir: str) -> None:
        self.base_dir = base_dir

    def write_file(self, path: str, contents: str) -> None:
        """Overwrite `base_dir/path` with `contents`, encoded as UTF-8."""
        # TOML documents are required to be UTF-8: don't let the dispatcher's locale pick the encoding.
        with open(os.path.join(self.base_dir, path), "w", encoding="utf-8") as f:
            f.write(contents)

    def makedirs(self, path: str) -> None:
        """Create `base_dir/path` and any missing parents; no error if it already exists."""
        os.makedirs(os.path.join(self.base_dir, path), exist_ok=True)

    def exec_(self, cli: list[str]) -> None:
        """
        Run the command `cli` (argv list, no shell).

        stdout is captured and discarded; raises `subprocess.CalledProcessError`
        if the command exits non-zero.
        """
        subprocess.check_output(cli)
class DryRunOps(Ops):
    """`Ops` variant with no side effects: every operation only logs what it would have done."""

    def write_file(self, path: str, contents: str) -> None:
        """Log the target path at INFO and the would-be contents at DEBUG; write nothing."""
        notice = f"dry-run: not writing '{path}'"
        logger.info(notice)
        logger.debug(contents)

    def makedirs(self, path: str) -> None:
        """Log the directory that would have been created; create nothing."""
        notice = f"dry-run: not making dirs '{path}'"
        logger.info(notice)

    def exec_(self, cli: list[str]) -> None:
        """Log the command line that would have been run; run nothing."""
        command_line = " ".join(cli)
        logger.info(f"dry-run: not `exec`ing: {command_line}")
class NmConfig:
    """
    The DHCP-related settings NetworkManager-dispatcher passes through the environment.

    Attributes:
    - nameservers: DHCP4_DOMAIN_NAME_SERVERS split on spaces, empty entries dropped.
    - conn_id: CONNECTION_ID passed through `sanitize_name`; "unknown" when the
      variable is unset or nothing survives sanitization.
    - search_domains: DHCP4_DOMAIN_SEARCH split on spaces, empty entries dropped.
    """

    def __init__(self) -> None:
        raw_nameservers = os.environ.get("DHCP4_DOMAIN_NAME_SERVERS", "").split(" ")
        raw_search_domains = os.environ.get("DHCP4_DOMAIN_SEARCH", "").split(" ")
        self.nameservers = [ns for ns in raw_nameservers if ns]
        # a connection id made solely of disallowed characters would otherwise sanitize to ""
        # and yield the dotfile "dhcp-configs/.toml", which globs like `*.toml` skip.
        # NOTE(review): confirm how dhcp-configs/ is merged by the service.
        self.conn_id = sanitize_name(os.environ.get("CONNECTION_ID", "unknown")) or "unknown"
        self.search_domains = [d for d in raw_search_domains if d]
def sanitize_name(name: str) -> str:
    """
    Keep only the characters [A-Za-z0-9_-] of `name` (case preserved), for use in a filename.

    May return "" if no character of `name` is allowed.
    """
    # test membership against the explicit ASCII set: `c.lower() in <lowercase set>` also let
    # through non-ASCII characters that lowercase to ASCII, e.g. KELVIN SIGN (U+212A) -> "k".
    allowed = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-"
    return "".join(c for c in name if c in allowed)
def is_valid_search_domain(domain: str) -> bool:
    """
    Decide whether a DHCP-issued search domain may get its own forward zone.

    Accepts any domain with at least two non-empty dot-separated labels.
    Single-label names (e.g. "lan") and empty strings are rejected with a warning,
    to avoid handing a very large domain space to the DHCP-issued servers.
    """
    labels = [label for label in domain.split(".") if label]
    if len(labels) >= 2:
        # allow any search domain that's not a TLD.
        return True
    # if you trigger this, then whitelist the search domain here.
    # it's ok to have a search domain of any length -- i'm just hesitant to allow hijacking of very large domain spaces.
    logger.warning("invalid search domain %s", domain)
    return False
def format_zone(domain: str, nameservers: list[str]) -> str:
    """
    Render a trust-dns `[[zones]]` TOML entry forwarding `domain` to `nameservers` (UDP, port 53).

    pre-requisites: nameservers is non-empty and no nameserver is "".
    domain is the human-friendly domain, trailing dot is optional.
    IPv6 nameserver literals are bracketed (`[addr]:53`) so the socket address stays parseable.

    Raises ValueError for an empty `domain`, rather than turning it into the root zone ".".
    """
    assert nameservers, f"no nameservers for zone {domain}"
    if not domain:
        raise ValueError("empty domain")
    if not domain.endswith("."):
        domain += "."
    server_entries = []
    for ns in nameservers:
        assert ns, "empty nameserver"
        # `ip:port` is ambiguous for IPv6 literals; they must be written `[ip]:port`.
        socket_addr = f"[{ns}]:53" if ":" in ns else f"{ns}:53"
        server_entries.append(
            f' {{ socket_addr = "{socket_addr}", protocol = "udp", trust_nx_responses = false }}'
        )
    header = (
        '[[zones]]\n'
        f'zone = "{domain}"\n'
        'zone_type = "Forward"\n'
        'stores = { type = "forward", name_servers = [\n'
    )
    return header + ",\n".join(server_entries) + "\n]}"
def apply_zone(nm_config: NmConfig, ops: Ops) -> None:
    """
    Rewrite this connection's zone file from `nm_config`, then restart the resolver.

    One `[[zones]]` entry is produced per valid search domain, forwarding to the
    DHCP-issued nameservers. When no entry qualifies (no valid search domain, or
    no nameservers), the file is still written, with empty contents.
    """
    zones = [
        format_zone(domain, nm_config.nameservers)
        for domain in nm_config.search_domains
        if is_valid_search_domain(domain) and nm_config.nameservers
    ]
    rendered = "".join(zone + "\n" for zone in zones)
    if rendered:
        # formatting preference: when these configs are `cat`d together, empty-line separators help distinguish
        rendered = "\n" + rendered
    ops.makedirs(DHCP_CONFIGS)
    ops.write_file(f"{DHCP_CONFIGS}/{nm_config.conn_id}.toml", rendered)
    # TODO: don't restart if the merged config is expected to be the same;
    # restarts are costly, especially since they dump the cache!
    restart_cmd = ["systemctl", "restart", "trust-dns-localhost"]
    ops.exec_(restart_cmd)
def main():
    """
    Handle one NetworkManager-dispatcher invocation.

    CLI: `[--dns-dir DIR] [--verbose] [--dry-run] <interface> <action>`.
    "down" clears this connection's DHCP-derived zones; "up", "dhcp4-change" and
    "dns-change" (re)apply them from the environment; other actions are only logged.
    """
    logging.basicConfig()
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    logger.info('invoked')

    parser = argparse.ArgumentParser(description='update trust-dns config in response to NetworkManager event')
    parser.add_argument('--dns-dir', default=DNS_DIR)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--dry-run', action='store_true')
    parser.add_argument('interface')
    parser.add_argument('action', help='name of the NetworkManager action this script is responding to')
    args = parser.parse_args()

    if args.verbose:
        root_logger.setLevel(logging.DEBUG)

    # --dry-run swaps in an Ops implementation that only logs.
    ops_factory = DryRunOps if args.dry_run else Ops
    ops = ops_factory(args.dns_dir)

    nm_config = NmConfig()
    logger.info(f"dhcp nameservers: '{' '.join(nm_config.nameservers)}'")
    logger.info(f"sanitized connection id: '{nm_config.conn_id}'")
    logger.info(f"search domains: '{' '.join(nm_config.search_domains)}'")

    action = args.action
    if action == "down":
        logger.info("action: down: clearing DHCP-issued servers")
        nm_config.search_domains = []
        nm_config.nameservers = []
        return apply_zone(nm_config, ops)
    if action in ("dhcp4-change", "dns-change", "up"):
        logger.info(f"action: {action}: applying DHCP settings")
        return apply_zone(nm_config, ops)
    logger.info(f"action: {action}: no handler")
if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        # catch exceptions here so the script still exits `0` and NetworkManager-dispatcher doesn't abort.
        # (argparse usage errors raise SystemExit, which is not an `Exception`, so those still exit non-zero.)
        # log once, through the module logger, with the traceback attached.
        logger.exception("caught exception: %s", e)