diff --git a/puzzles/4.py b/puzzles/4.py new file mode 100644 index 0000000..d50be5b --- /dev/null +++ b/puzzles/4.py @@ -0,0 +1,84 @@ +from typing import Iterator + + +test_input_p1 = test_input_p2 = """ +MMMSXXMASM +MSAMXMSMSA +AMXSXMAAMM +MSAMASMSMX +XMASAMXAMM +XXAMMXXAMA +SMSMSASXSS +SAXAMASAAA +MAMMMXMMMM +MXMXAXMASX +""".strip() + +test_solution_p1 = 18 +test_solution_p2 = 9 + + +def solve_p1(puzzle_input: str) -> int: + word_grid = _parse_world_grid(puzzle_input) + return sum(_count_xmas(line) for line in _scan_word_grid(word_grid)) + + +def solve_p2(puzzle_input: str) -> int: + word_grid = _parse_world_grid(puzzle_input) + squares = _three_squares(word_grid) + return sum(1 for square in squares if _has_cross_mas(square)) + + +def _parse_world_grid(puzzle_input: str) -> tuple[tuple[str, ...], ...]: + return tuple(tuple(row) for row in puzzle_input.splitlines()) + + +def _scan_word_grid( + char_grid: tuple[tuple[str, ...], ...] +) -> Iterator[tuple[str, ...]]: + yield from (row for row in char_grid) + yield from (tuple(col) for col in zip(*char_grid)) + yield from ( + _diagonal(char_grid, i) + for i in range(-len(char_grid) + 1, len(char_grid[0])) + ) + yield from ( + _diagonal(tuple(reversed(char_grid)), i) + for i in range(-len(char_grid) + 1, len(char_grid[0])) + ) + + +def _diagonal(grid: tuple[tuple[str, ...], ...], offset=0) -> tuple[str, ...]: + return tuple( + grid[i][i + offset] for i in range(len(grid)) if 0 <= i + offset < len(grid) + ) + + +def _count_xmas(line: tuple[str, ...]) -> int: + return sum( + 1 + for i + in range(len(line) - 3) + if line[i : i + 4] in (("X", "M", "A", "S"), ("S", "A", "M", "X")) + ) + +ThreeSquare = tuple[ + tuple[str, str, str], + tuple[str, str, str], + tuple[str, str, str] +] + +def _three_squares(word_grid: tuple[tuple[str, ...], ...]) -> Iterator[ThreeSquare]: + yield from ( + (word_grid[i][j:j+3], word_grid[i+1][j:j+3], word_grid[i+2][j:j+3]) + for i in range(len(word_grid) - 2) + for j in range(len(word_grid[0]) - 2) + ) + +def _has_cross_mas(square: ThreeSquare): + diag_1 = (square[0][0], square[1][1], square[2][2]) + diag_2 = (square[0][2], square[1][1], square[2][0]) + return ( + (diag_1 == ("M", "A", "S") or diag_1 == ("S", "A", "M")) + and (diag_2 == ("M", "A", "S") or diag_2 == ("S", "A", "M")) + ) \ No newline at end of file