from typing import Iterator

test_input = """
MMMSXXMASM
MSAMXMSMSA
AMXSXMAAMM
MSAMASMSMX
XMASAMXAMM
XXAMMXXAMA
SMSMSASXSS
SAXAMASAAA
MAMMMXMMMM
MXMXAXMASX
""".strip()

test_solution_p1 = 18
test_solution_p2 = 9


def solve_p1(puzzle_input: str) -> int:
    """Count every occurrence of "XMAS" in the word grid.

    Words are matched along rows, columns and both diagonal directions,
    each read forwards and backwards.
    """
    word_grid = _parse_world_grid(puzzle_input)
    return sum(_count_xmas(line) for line in _scan_word_grid(word_grid))


def solve_p2(puzzle_input: str) -> int:
    """Count 3x3 windows whose two diagonals both spell "MAS" in either direction."""
    word_grid = _parse_world_grid(puzzle_input)
    squares = _three_squares(word_grid)
    return sum(1 for square in squares if _has_cross_mas(square))


def _parse_world_grid(puzzle_input: str) -> tuple[tuple[str, ...], ...]:
    """Split the puzzle text into a grid of single characters, one row per line."""
    # NOTE: "world" looks like a typo for "word"; the name is kept unchanged so
    # existing callers keep working.
    return tuple(tuple(row) for row in puzzle_input.splitlines())


def _scan_word_grid(
    char_grid: tuple[tuple[str, ...], ...]
) -> Iterator[tuple[str, ...]]:
    """Yield every straight line of the grid that a word could lie on.

    Yields all rows, then all columns, then all top-left -> bottom-right
    diagonals, then all bottom-left -> top-right diagonals. Only one reading
    direction per line is yielded; callers must check both spellings.
    An empty grid yields nothing.
    """
    if not char_grid:
        # Guard: ``char_grid[0]`` below would raise IndexError on empty input.
        return
    n_rows = len(char_grid)
    n_cols = len(char_grid[0])
    # Offsets cover every diagonal that touches at least one cell: the most
    # negative one starts in the bottom-left corner, the most positive one in
    # the top-right corner.
    offsets = range(-n_rows + 1, n_cols)

    yield from char_grid
    yield from (tuple(col) for col in zip(*char_grid))
    yield from (_diagonal(char_grid, offset) for offset in offsets)
    # Flipping the rows upside down turns anti-diagonals into main diagonals.
    flipped = tuple(reversed(char_grid))
    yield from (_diagonal(flipped, offset) for offset in offsets)


def _diagonal(
    grid: tuple[tuple[str, ...], ...], offset: int = 0
) -> tuple[str, ...]:
    """Return the cells ``grid[i][i + offset]`` that lie inside the grid.

    ``offset == 0`` is the diagonal starting at the top-left corner; positive
    offsets start further right along the top row, negative offsets start
    further down the left column.
    """
    return tuple(
        grid[i][i + offset]
        for i in range(len(grid))
        # Bound by the row's own width, not the row count: the grid need not
        # be square.
        if 0 <= i + offset < len(grid[i])
    )


def _count_xmas(line: tuple[str, ...]) -> int:
    """Count "XMAS" occurrences in ``line``, read forwards or backwards."""
    return sum(
        1
        for i in range(len(line) - 3)
        if line[i : i + 4] in (("X", "M", "A", "S"), ("S", "A", "M", "X"))
    )


ThreeSquare = tuple[tuple[str, str, str], tuple[str, str, str], tuple[str, str, str]]


def _three_squares(word_grid: tuple[tuple[str, ...], ...]) -> Iterator[ThreeSquare]:
    """Yield every 3x3 window of the grid, row by row, left to right."""
    yield from (
        (
            word_grid[i][j : j + 3],
            word_grid[i + 1][j : j + 3],
            word_grid[i + 2][j : j + 3],
        )
        for i in range(len(word_grid) - 2)
        for j in range(len(word_grid[0]) - 2)
    )


def _has_cross_mas(square: ThreeSquare) -> bool:
    """Return True if both diagonals of ``square`` spell "MAS" (either direction)."""
    diag_1 = (square[0][0], square[1][1], square[2][2])
    diag_2 = (square[0][2], square[1][1], square[2][0])
    mas_spellings = (("M", "A", "S"), ("S", "A", "M"))
    return diag_1 in mas_spellings and diag_2 in mas_spellings