"""Find equations whose target can be built by inserting operators between
their factors, with the resulting expression evaluated strictly left to right.

Each input line has the form ``"<target>: <factor> <factor> ..."``.
"""

from functools import partial
from typing import Callable, Iterator, NamedTuple
from math import log

test_input = """
190: 10 19
3267: 81 40 27
83: 17 5
156: 15 6
7290: 6 8 6 15
161011: 16 10 13
192: 17 8 14
21037: 9 7 18 13
292: 11 6 16 20
""".strip()

test_solution_p1 = 3749
test_solution_p2 = 11387


class Equation(NamedTuple):
    """A parsed input line: the desired result and the ordered operands."""

    target: int
    factors: tuple[int, ...]


# A binary operator applied as ``operator(accumulated_value, next_factor)``.
Operator = Callable[[int, int], int]


def solve_p1(puzzle_input: str) -> int:
    """Sum the targets of equations solvable with ``+`` and ``*``."""
    return _sum_calibrated(puzzle_input, [int.__add__, int.__mul__])


def solve_p2(puzzle_input: str) -> int:
    """Sum the targets of equations solvable with ``+``, ``*`` and digit concatenation."""
    return _sum_calibrated(
        puzzle_input,
        [int.__add__, int.__mul__, _concat_ints],
    )


def _sum_calibrated(puzzle_input: str, operators: list[Operator]) -> int:
    """Return the sum of targets of all equations reachable with *operators*."""
    calibration_guard = partial(_is_calibrated, operators)
    return sum(
        equation.target
        for equation in filter(calibration_guard, _parse_equations(puzzle_input))
    )


def _parse_equations(puzzle_input: str) -> Iterator[Equation]:
    """Yield one Equation per non-blank ``"<target>: <f1> <f2> ..."`` line.

    Tolerates CRLF line endings, surrounding whitespace and repeated spaces.

    Raises:
        ValueError: if a line has no factors or contains a non-integer token.
    """
    for raw_line in puzzle_input.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        target_string, _, factors_string = line.partition(":")
        factors = tuple(map(int, factors_string.split()))
        if not factors:
            raise ValueError(f"equation has no factors: {line!r}")
        yield Equation(int(target_string), factors)


def _is_calibrated(operators: list[Operator], equation: Equation) -> bool:
    """Return True if some choice of *operators* between the factors,
    evaluated left to right, produces ``equation.target``.

    Tracks the set of distinct values reachable after each factor, so
    duplicate partial results are evaluated only once.
    """
    if not equation.factors:
        return False
    target = equation.target
    first, *rest = equation.factors
    # Dropping partial values above the target is only sound when the
    # operators can never decrease the value. For +, * and concatenation
    # that holds as long as every factor is >= 1. A 0 factor would let
    # ``x * 0`` shrink a too-large value back down.
    can_prune = all(factor >= 1 for factor in equation.factors)
    reachable = {first}
    for factor in rest:
        if can_prune:
            reachable = {value for value in reachable if value <= target}
        reachable = {
            operator(value, factor)
            for value in reachable
            for operator in operators
        }
    return target in reachable


def _concat_ints(a: int, b: int) -> int:
    """Return the integer whose decimal digits are those of *a* then *b*.

    Assumes *b* is non-negative, e.g. ``_concat_ints(12, 345) == 12345``.
    """
    # Count digits exactly via str(): float log10 misrounds at powers of ten
    # (log(1000, 10) == 2.9999999999999996) and raises on b == 0.
    return a * 10 ** len(str(b)) + b