#!/usr/bin/env nix-shell
#!nix-shell -i python3 -p "python3.withPackages (ps: [ ])" -p systemd
# vim: set filetype=python :
|
|
|
|
# /etc/NetworkManager/dispatcher.d/trust-dns-nmhook:
# NetworkManager-dispatcher.service calls this script whenever any network changes state.
# this includes when we activate a new network and receive DHCP info.
# specifically, this script propagates DHCP info to my DNS setup,
# ensuring things like "search domains" work (sorta) with my recursive resolver.
#
# NetworkManager-dispatcher invokes this with env vars related to the action/device/connection. notably:
# - DEVICE_IFACE (e.g. "wlp3s0")
# - DHCP4_DOMAIN_NAME_SERVERS (e.g. "1.1.1.1 4.4.4.4")
# - DHCP4_DOMAIN_SEARCH (e.g. "home.lan uninsane.org")
# - IP4_NAMESERVERS (e.g. "1.1.1.1")
# - CONNECTION_ID (e.g. "my-ssid-name")
# - CONNECTION_FILENAME (e.g. "/etc/NetworkManager/system-connections/XfinityWifi.nmconnection")
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import subprocess
|
|
|
|
# module-level logger; the root logger's handler and level are configured in main().
logger = logging.getLogger(__name__)

# default base directory for the files this script writes; overridable with `--dns-dir`.
DNS_DIR = "/var/lib/trust-dns"
# subdirectory (relative to the base directory) holding one generated .toml file per connection.
DHCP_CONFIGS = "dhcp-configs"
|
|
|
|
class Ops:
    """
    performs the real side effects of this script: filesystem writes and subprocess calls.
    paths handed to the file operations are joined onto `base_dir`.
    """

    def __init__(self, base_dir: str) -> None:
        self.base_dir = base_dir

    def write_file(self, path: str, contents: str) -> None:
        """ create or truncate `path` (joined onto the base dir) and write `contents` to it. """
        full_path = os.path.join(self.base_dir, path)
        with open(full_path, "w") as out:
            out.write(contents)

    def makedirs(self, path: str) -> None:
        """ create `path` (joined onto the base dir) plus missing parents; an existing directory is not an error. """
        full_path = os.path.join(self.base_dir, path)
        os.makedirs(full_path, exist_ok=True)

    def exec_(self, cli: list[str]) -> None:
        """ run `cli` and wait for it; its captured stdout is discarded. raises CalledProcessError on non-zero exit. """
        subprocess.check_output(cli)
|
|
|
|
class DryRunOps(Ops):
    """
    stand-in for `Ops` which never touches the system:
    every operation is only reported through the module logger.
    """

    def write_file(self, path: str, contents: str) -> None:
        """ log the path at INFO and the would-be contents at DEBUG. """
        notice = f"dry-run: not writing '{path}'"
        logger.info(notice)
        logger.debug(contents)

    def makedirs(self, path: str) -> None:
        """ log the directory which would have been created. """
        notice = f"dry-run: not making dirs '{path}'"
        logger.info(notice)

    def exec_(self, cli: list[str]) -> None:
        """ log the command line (space-joined) which would have been run. """
        rendered = ' '.join(cli)
        logger.info(f"dry-run: not `exec`ing: {rendered}")
|
|
|
|
class NmConfig:
    """
    the subset of the NetworkManager-dispatcher environment which this script acts on.

    - nameservers: space-separated entries of DHCP4_DOMAIN_NAME_SERVERS, empty entries dropped.
    - conn_id: CONNECTION_ID (or "unknown" when unset), passed through `sanitize_name`.
    - search_domains: space-separated entries of DHCP4_DOMAIN_SEARCH, empty entries dropped.
    """

    def __init__(self) -> None:
        env = os.environ

        raw_nameservers = env.get("DHCP4_DOMAIN_NAME_SERVERS", "")
        raw_conn_id = env.get("CONNECTION_ID", "unknown")
        raw_search = env.get("DHCP4_DOMAIN_SEARCH", "")

        # splitting on a single space yields "" for an unset variable or for repeated spaces: filter those out.
        self.nameservers = [entry for entry in raw_nameservers.split(" ") if entry]
        self.conn_id = sanitize_name(raw_conn_id)
        self.search_domains = [entry for entry in raw_search.split(" ") if entry]
|
|
|
|
|
|
def sanitize_name(name: str) -> str:
    """
    reduce `name` to the characters [A-Za-z0-9_-], dropping everything else.
    order and case of the surviving characters are preserved; the result may be empty.
    """
    # test the character itself rather than `c.lower()`: some non-ASCII characters
    # (e.g. U+212A KELVIN SIGN) lowercase to an ASCII letter and would otherwise be kept as-is.
    allowed = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-"
    return "".join(c for c in name if c in allowed)
|
|
|
|
def is_valid_search_domain(domain: str) -> bool:
    """
    decide whether `domain` is acceptable as a search domain.
    a domain is accepted when it has at least two non-empty dot-separated labels,
    i.e. anything that's not a bare TLD. rejected domains are logged as a warning.
    """
    labels = [label for label in domain.split(".") if label]
    if len(labels) >= 2:
        # allow any search domain that's not a TLD.
        return True
    # `Logger.warn` is a deprecated alias: use `Logger.warning`.
    logger.warning(f"invalid search domain {domain}")  # if you trigger this, then whitelist the search domain here
    # it's ok to have a search domain of any length -- i'm just hesitant to allow hijacking of very large domain spaces.
    return False
|
|
|
|
def format_zone(domain: str, nameservers: list[str]) -> str:
    """
    render a `[[zones]]` TOML stanza which forwards `domain` to `nameservers`.

    pre-requisites: domain is non-empty, nameservers is non-empty and no nameserver is "".
    domain is the human-friendly domain, trailing dot is optional.
    each nameserver is a bare address: ":53" is appended to it here.

    raises AssertionError if a pre-requisite is violated.
    returns the stanza, without a trailing newline.
    """
    assert nameservers, f"no nameservers for zone {domain}"
    # fail with a clear message instead of an IndexError from indexing an empty string.
    assert domain, "empty domain"
    if not domain.endswith("."):
        domain += "."

    entries = []
    for ns in nameservers:
        assert ns, "empty nameserver"
        # N.B.: `domain` and `ns` are interpolated verbatim; nothing here escapes TOML metacharacters.
        entries.append(f' {{ socket_addr = "{ns}:53", protocol = "udp", trust_nx_responses = false }}')

    header = (
        '[[zones]]\n'
        f'zone = "{domain}"\n'
        'zone_type = "Forward"\n'
        'stores = { type = "forward", name_servers = [\n'
    )
    return header + ",\n".join(entries) + '\n]}'
|
|
|
|
def apply_zone(nm_config: NmConfig, ops: Ops) -> None:
    """
    (re)generate the per-connection config file from `nm_config`, then restart the resolver service.

    one forward zone is emitted per valid search domain, provided there's at least one nameserver.
    when nothing qualifies, the file is still written, just empty.
    all side effects go through `ops`.
    """
    # `is_valid_search_domain` is evaluated first for every domain, so rejected domains
    # are reported even when there are no nameservers to forward to.
    stanzas = [
        format_zone(domain, nm_config.nameservers) + '\n'
        for domain in nm_config.search_domains
        if is_valid_search_domain(domain) and nm_config.nameservers
    ]
    specialized_config = "".join(stanzas)
    if specialized_config:
        # formatting preference: when these configs are `cat`d together, empty-line separators help distinguish
        specialized_config = '\n' + specialized_config

    ops.makedirs(DHCP_CONFIGS)
    ops.write_file(f"{DHCP_CONFIGS}/{nm_config.conn_id}.toml", specialized_config)

    # TODO: don't restart if the merged config is expected to be the same;
    # restarts are costly, especially since they dump the cache!
    restart_cli = ["systemctl", "restart", "trust-dns-localhost"]
    ops.exec_(restart_cli)
|
|
|
|
def main():
    """
    entry point: parse the CLI, snapshot the NetworkManager environment, and dispatch on the action.

    - "down": clear the DHCP-provided nameservers/search domains, then apply.
    - "dhcp4-change", "dns-change", "up": apply the DHCP-provided settings.
    - anything else: only logged.
    """
    logging.basicConfig()
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    logger.info('invoked')

    parser = argparse.ArgumentParser(description='update trust-dns config in response to NetworkManager event')
    parser.add_argument('--dns-dir', default=DNS_DIR)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--dry-run', action='store_true')
    parser.add_argument('interface')
    parser.add_argument('action', help='name of the NetworkManager action this script is responding to')
    args = parser.parse_args()

    if args.verbose:
        root_logger.setLevel(logging.DEBUG)

    # same constructor signature either way; only the side effects differ.
    ops_cls = DryRunOps if args.dry_run else Ops
    ops = ops_cls(args.dns_dir)

    nm_config = NmConfig()
    logger.info(f"dhcp nameservers: '{' '.join(nm_config.nameservers)}'")
    logger.info(f"sanitized connection id: '{nm_config.conn_id}'")
    logger.info(f"search domains: '{' '.join(nm_config.search_domains)}'")

    action = args.action
    if action == "down":
        logger.info("action: down: clearing DHCP-issued servers")
        # with both lists emptied, apply_zone writes an empty config for this connection.
        nm_config.search_domains = []
        nm_config.nameservers = []
        return apply_zone(nm_config, ops)

    if action in ("dhcp4-change", "dns-change", "up"):
        logger.info(f"action: {action}: applying DHCP settings")
        return apply_zone(nm_config, ops)

    logger.info(f"action: {action}: no handler")
|
|
|
|
|
|
if __name__ == '__main__':
    try:
        main()
    except Exception as exc:
        # catch exceptions here and always return `0`, so NetworkManager-dispatcher doesn't abort
        # N.B.: only `Exception` subclasses are handled here; a SystemExit (such as argparse
        # raises for bad arguments) is not an `Exception` and still propagates.
        logger.info(f"caught exception: {exc}")
        logging.exception(exc)
|