Files
advent-of-code-2024/puzzles/7.py

110 lines
2.9 KiB
Python

from functools import partial
from math import log
from textwrap import dedent
from typing import Callable, Iterator, NamedTuple, override
from unittest import TestCase
from puzzles._solver import Solver
# Signature of a binary operator that can be placed between two factors.
Operator = Callable[[int, int], int]


class Equation(NamedTuple):
    """A single input line: a target value and the ordered factors.

    The operators between the factors are not given; callers decide whether
    some choice of operators makes the factors evaluate to ``target``.
    """

    # The value the factors must evaluate to.
    target: int
    # Operands in input order (order matters for evaluation).
    factors: tuple[int, ...]
def is_calibrated(operators: list[Operator], equation: Equation) -> bool:
    """Return whether some choice of *operators* makes *equation* hold.

    One operator is placed between each pair of adjacent factors and the
    expression is evaluated strictly left to right (no precedence).

    Args:
        operators: Binary operators allowed between factors.
        equation: The target value and the ordered factors.

    Returns:
        True if at least one assignment of operators evaluates to
        ``equation.target``; False otherwise, including when there are no
        factors at all.
    """
    target = equation.target
    factors = equation.factors
    if not factors:
        return False

    # Dropping intermediate values that already exceed the target is only
    # sound when no later step can shrink the running value. That holds for
    # operators that never return less than their left operand when the right
    # operand is positive (e.g. +, * and digit concatenation), provided every
    # factor is positive. A 0 (x * 0) or a negative factor could bring an
    # oversized prefix back down to the target, so pruning is disabled then.
    can_prune = all(factor > 0 for factor in factors)

    def evaluations(prefix_len: int) -> Iterator[int]:
        """Yield every value the first *prefix_len* factors can evaluate to."""
        tail = factors[prefix_len - 1]
        if prefix_len == 1:
            yield tail
            return
        for head in evaluations(prefix_len - 1):
            if can_prune and head > target:
                continue
            for operator in operators:
                yield operator(head, tail)

    # any() short-circuits, so the search stops at the first matching value.
    return any(result == target for result in evaluations(len(factors)))
def concat_ints(a: int, b: int) -> int:
    """Return the integer formed by writing the digits of *b* after *a*.

    >>> concat_ints(12, 345)
    12345

    Args:
        a: Left operand; its digits come first.
        b: Right operand; must be non-negative.

    Returns:
        ``a`` shifted left by the number of decimal digits in ``b``, plus ``b``.

    Raises:
        ValueError: If ``b`` is negative.
    """
    if b < 0:
        raise ValueError(f"cannot concatenate a negative right operand: {b}")
    # Count digits exactly from the decimal string. math.log(b, 10) is a float
    # and undershoots at exact powers of ten (log(1000, 10) is
    # 2.9999999999999996), which drops a digit; it also fails for b == 0.
    return a * 10 ** len(str(b)) + b
class TestEquation(TestCase):
    """Unit tests for the module-level helpers ``is_calibrated`` and ``concat_ints``."""

    def test_is_calibrated(self):
        # 10 * 19 == 190, but 10 + 19 != 190.
        sample = Equation(190, (10, 19))
        with_multiplication = [int.__add__, int.__mul__]
        addition_only = [int.__add__]
        self.assertTrue(is_calibrated(with_multiplication, sample))
        self.assertFalse(is_calibrated(addition_only, sample))

    def test_concat_ints(self):
        cases = (
            (1, 2, 12),
            (12, 345, 12345),
        )
        for left, right, expected in cases:
            with self.subTest(left=left, right=right):
                self.assertEqual(concat_ints(left, right), expected)
class DaySevenSolver(Solver):
    """Solver for day 7: sum the targets of the satisfiable equations.

    Each input line has the form ``"<target>: <factor> <factor> ..."``. The
    two parts differ only in which operators may be placed between factors.
    """

    equations: list[Equation]

    @override
    def __init__(self, puzzle_input: str):
        """Parse *puzzle_input* into ``self.equations``, one per non-blank line.

        Raises:
            ValueError: If a line lacks the ``": "`` separator or any factors,
                or if a target/factor is not an integer.
        """
        self.equations = []
        for line in puzzle_input.strip().splitlines():
            if not line.strip():
                # Tolerate stray blank lines rather than failing on int("").
                continue
            target_text, separator, factors_text = line.partition(": ")
            # split() without an argument also copes with repeated spaces.
            factors = tuple(map(int, factors_text.split()))
            if not separator or not factors:
                raise ValueError(f"malformed equation line: {line!r}")
            self.equations.append(Equation(int(target_text), factors))

    def _calibration_total(self, operators: list[Operator]) -> int:
        """Sum the targets of every equation solvable with *operators*."""
        return sum(
            equation.target
            for equation in self.equations
            if is_calibrated(operators, equation)
        )

    @override
    def solve_p1(self) -> int:
        """Total calibration result using addition and multiplication."""
        return self._calibration_total([int.__add__, int.__mul__])

    @override
    def solve_p2(self) -> int:
        """Total calibration result when digit concatenation is also allowed."""
        return self._calibration_total([int.__add__, int.__mul__, concat_ints])
class TestDaySevenSolver(TestCase):
    """End-to-end check of ``DaySevenSolver`` on the worked example input."""

    def test(self):
        example = dedent(
            """
            190: 10 19
            3267: 81 40 27
            83: 17 5
            156: 15 6
            7290: 6 8 6 15
            161011: 16 10 13
            192: 17 8 14
            21037: 9 7 18 13
            292: 11 6 16 20
            """
        )
        solver = DaySevenSolver(example)

        # Parsing: the first line becomes the expected Equation.
        first_equation = solver.equations[0]
        self.assertEqual(first_equation, Equation(190, (10, 19)))

        # Expected totals for the example input.
        self.assertEqual(solver.solve_p1(), 3749)
        self.assertEqual(solver.solve_p2(), 11387)