trust-dns: recursor: merge DHCP DNS servers from all non-downed connections

otherwise overwriting the toml configs gets messy, when interfaces come up in unpredictable order
This commit is contained in:
Colin 2024-05-14 08:25:59 +00:00
parent 53198128e8
commit 2a199bf373
3 changed files with 38 additions and 33 deletions

View File

@ -51,7 +51,7 @@
listenAddrs = [ "127.0.0.1" ];
enableRecursiveResolver = true;
# append zones discovered via DHCP to the resolver config.
includes = [ "/var/lib/trust-dns/dhcp-zones.toml" ];
includes = [ "/var/lib/trust-dns/dhcp-configs/*" ];
};
networking.nameservers = [
"127.0.0.1"

View File

@ -37,6 +37,7 @@ let
description = ''
list of paths to cat into the final config.
non-existent paths are skipped.
supports shell-style globbing.
'';
};
enableRecursiveResolver = mkOption {
@ -100,7 +101,7 @@ let
mkdir -p "/var/lib/trust-dns/${flavor}"
${sed} ${subs} -e "" "${configTemplate}" \
| cat - \
${lib.escapeShellArgs includes} \
${lib.concatStringsSep " " includes} \
> "${configPath}" || true
''] ++ lib.mapAttrsToList (zone: { rendered, ... }: ''
${sed} ${subs} -e "" ${pkgs.writeText "${zone}.zone.in" rendered} \

View File

@ -24,14 +24,12 @@ import subprocess
logger = logging.getLogger(__name__)
DNS_DIR = "/var/lib/trust-dns"
DHCP_CONFIGS = "dhcp-configs"
class Ops:
def __init__(self, base_dir: str) -> None:
self.base_dir = base_dir
def read_file(self, path: str) -> str:
return open(os.path.join(self.base_dir, path)).read()
def write_file(self, path: str, contents: str) -> None:
with open(os.path.join(self.base_dir, path), "w") as f:
f.write(contents)
@ -39,10 +37,6 @@ class Ops:
def makedirs(self, path: str) -> None:
os.makedirs(os.path.join(self.base_dir, path), exist_ok=True)
def copy_file(self, from_: str, to_: str) -> None:
contents = self.read_file(from_)
self.write_file(to_, contents)
def exec_(self, cli: list[str]) -> None:
subprocess.check_output(cli)
@ -61,7 +55,7 @@ class NmConfig:
def __init__(self) -> None:
nameservers = os.environ.get("DHCP4_DOMAIN_NAME_SERVERS", "").split(" ")
nameservers = [ns for ns in nameservers if ns]
conn_id = sanitizeName(os.environ.get("CONNECTION_ID", "unknown"))
conn_id = sanitize_name(os.environ.get("CONNECTION_ID", "unknown"))
search_domains = os.environ.get("DHCP4_DOMAIN_SEARCH", "").split(" ")
search_domains = [d for d in search_domains if d]
@ -70,10 +64,10 @@ class NmConfig:
self.search_domains = search_domains
def sanitizeName(name: str) -> str:
def sanitize_name(name: str) -> str:
return "".join(c for c in name if c.lower() in "abcdefghijklmnopqrstuvwxyz0123456789_-")
def isValidSearchDomain(domain: str) -> bool:
def is_valid_search_domain(domain: str) -> bool:
comps = [c for c in domain.split(".") if c]
if len(comps) >= 2:
# allow any search domain that's not a TLD.
@ -82,7 +76,7 @@ def isValidSearchDomain(domain: str) -> bool:
# it's ok to have a search domain of any length -- i'm just hesitant to allow hijacking of very large domain spaces.
return False
def formatZone(domain: str, nameservers: list[str]) -> str:
def format_zone(domain: str, nameservers: list[str]) -> str:
"""
pre-requisites: nameservers is non-empty and no nameserver is "".
domain is the human-friendly domain, trailing dot is optional.
@ -91,7 +85,7 @@ def formatZone(domain: str, nameservers: list[str]) -> str:
if domain[-1] != ".":
domain += "."
lines=(
f'''
f'''\
[[zones]]
zone = "{domain}"
zone_type = "Forward"
@ -105,6 +99,25 @@ stores = {{ type = "forward", name_servers = [
return lines
def apply_zone(nm_config: NmConfig, ops: Ops) -> None:
specialized_config = ""
for domain in nm_config.search_domains:
if is_valid_search_domain(domain) and nm_config.nameservers:
specialized_config += format_zone(domain, nm_config.nameservers) + '\n'
if specialized_config:
# formatting preference: when these configs are `cat`d together, empty-line separators help distinguish
specialized_config = '\n' + specialized_config
# TODO: i'm not sure how this behaves in the presence of multiple interfaces.
# do i want to persist config-per-interface, and then merge them, on every change?
conn_config_path = f"{DHCP_CONFIGS}/{nm_config.conn_id}.toml"
ops.makedirs(DHCP_CONFIGS)
ops.write_file(conn_config_path, specialized_config)
ops.exec_([
"systemctl",
"restart",
"trust-dns-localhost",
])
def main():
logging.basicConfig()
@ -134,26 +147,17 @@ def main():
logger.info(f"sanitized connection id: '{nm_config.conn_id}'")
logger.info(f"search domains: '{' '.join(nm_config.search_domains)}'")
if args.action not in ["dhcp4-change", "dns-change"]:
logger.info(f"action ({args.action}): no handler")
return
if args.action == "down":
logger.info("action: down: clearing DHCP-issued servers")
nm_config.search_domains = []
nm_config.nameservers = []
return apply_zone(nm_config, ops)
elif args.action in ["dhcp4-change", "dns-change", "up"]:
logger.info(f"action: {args.action}: applying DHCP settings")
return apply_zone(nm_config, ops)
else:
logger.info(f"action: {args.action}: no handler")
specializedConfig = ""
for domain in nm_config.search_domains:
if isValidSearchDomain(domain) and nm_config.nameservers:
specializedConfig += "\n" + formatZone(domain, nm_config.nameservers)
# TODO: i'm not sure how this behaves in the presence of multiple interfaces.
# do i want to persist config-per-interface, and then merge them, on every change?
connConfigPath = f"nmhook/{nm_config.conn_id}-dhcp.toml"
ops.makedirs("nmhook")
ops.write_file(connConfigPath, specializedConfig)
ops.copy_file(connConfigPath, "dhcp-zones.toml")
ops.exec_([
"systemctl",
"restart",
"trust-dns-localhost",
])
if __name__ == '__main__':
try: