88 lines
2.6 KiB
Python
88 lines
2.6 KiB
Python
from textwrap import dedent
|
|
from typing import Iterator, override
|
|
from unittest import TestCase
|
|
|
|
from more_itertools import sliding_window
|
|
|
|
from puzzles._solver import Solver
|
|
|
|
|
|
class DayFourSolver(Solver):
    """Word-search solver over a grid of single characters.

    Part one counts ``XMAS`` (read forwards or backwards) along every row,
    column and diagonal. Part two counts 3x3 squares whose two diagonals
    each read ``MAS`` in either direction.
    """

    # Rows of the puzzle, each row a tuple of single-character strings.
    # The width is taken from the first row, so rows are assumed to be of
    # equal length.
    grid: tuple[tuple[str, ...], ...]

    @override
    def __init__(self, puzzle_input: str):
        """Parse *puzzle_input* into :attr:`grid`.

        Surrounding whitespace is stripped, then each line becomes one row of
        characters.
        """
        self.grid = tuple(tuple(row) for row in puzzle_input.strip().splitlines())

    @property
    def _width(self) -> int:
        """Number of columns (length of the first row), or 0 for an empty grid."""
        return len(self.grid[0]) if self.grid else 0

    @override
    def solve_p1(self) -> int:
        """Count ``XMAS`` occurrences in every row, column and diagonal.

        Each scanned line is checked with a length-four sliding window. Both
        reading directions are matched, so ``SAMX`` also counts.
        """
        targets = (("X", "M", "A", "S"), ("S", "A", "M", "X"))
        return sum(
            1
            for line in self.scan_lines()
            for window in sliding_window(line, 4)
            if window in targets
        )

    @override
    def solve_p2(self) -> int:
        """Count 3x3 squares whose two diagonals both read ``MAS`` or ``SAM``."""
        targets = (("M", "A", "S"), ("S", "A", "M"))
        return sum(
            1
            for square in self.scan_squares(3)
            if (square[0][0], square[1][1], square[2][2]) in targets
            and (square[0][2], square[1][1], square[2][0]) in targets
        )

    def scan_lines(self) -> Iterator[tuple[str, ...]]:
        """Yield every row, then every column, then every diagonal of the grid.

        For each diagonal offset, the top-left to bottom-right diagonal is
        yielded first, followed by its top-right to bottom-left mirror.
        """
        yield from self.grid
        yield from zip(*self.grid)
        # Offsets run from -(rows - 1) to columns - 1, which covers every
        # diagonal exactly once. Using _width keeps an empty grid from
        # indexing a first row that does not exist.
        for offset in range(-len(self.grid) + 1, self._width):
            yield self.diagonal(offset)
            yield self.diagonal(offset, inverse=True)

    def scan_squares(self, size=3) -> Iterator[tuple[tuple[str, ...], ...]]:
        """Yield every ``size`` x ``size`` sub-square of the grid.

        Squares are yielded row by row of their top-left corner. Each square
        is a tuple of ``size`` row slices of length ``size``.
        """
        for top in range(len(self.grid) - size + 1):
            for left in range(self._width - size + 1):
                # Build exactly `size` rows, not a hard-coded three.
                yield tuple(
                    self.grid[top + k][left : left + size] for k in range(size)
                )

    def diagonal(self, offset=0, inverse=False) -> tuple[str, ...]:
        """Return one diagonal of the grid, read from the top row downwards.

        With ``inverse=False`` the diagonal holds cells ``(row, row + offset)``,
        running top-left to bottom-right. With ``inverse=True`` it holds the
        mirrored cells ``(row, width - 1 - row - offset)``, running top-right
        to bottom-left. Rows whose column would fall outside the grid are
        skipped.
        """
        width = self._width
        return tuple(
            self.grid[row][width - row - offset - 1 if inverse else row + offset]
            for row in range(len(self.grid))
            # Bound by the column count, not the row count. They differ for
            # non-square grids, and the row count would either drop cells or
            # index past the end of a row.
            if 0 <= row + offset < width
        )
|
|
|
|
|
|
class TestDayFourSolver(TestCase):
    """Exercise DayFourSolver on a small 10x10 example grid."""

    def test(self):
        example = dedent(
            """
            MMMSXXMASM
            MSAMXMSMSA
            AMXSXMAAMM
            MSAMASMSMX
            XMASAMXAMM
            XXAMMXXAMA
            SMSMSASXSS
            SAXAMASAAA
            MAMMMXMMMM
            MXMXAXMASX
            """
        )
        solver = DayFourSolver(example)

        # The first input line must be split into individual characters.
        expected_first_row = tuple("MMMSXXMASM")
        self.assertEqual(solver.grid[0], expected_first_row)

        # Expected answers for this example input.
        self.assertEqual(solver.solve_p1(), 18)
        self.assertEqual(solver.solve_p2(), 9)
|