Simplify test input

This commit is contained in:
2024-12-04 22:08:54 -08:00
parent 80b85b459e
commit 00a3bdd07b
6 changed files with 28 additions and 23 deletions

View File

@@ -1,6 +1,6 @@
from collections import Counter from collections import Counter
test_input_p1 = test_input_p2 = """ test_input = """
3 4 3 4
4 3 4 3
2 5 2 5

View File

@@ -3,7 +3,7 @@ from typing import Iterator
from more_itertools import ilen from more_itertools import ilen
test_input_p1 = test_input_p2 = """ test_input = """
7 6 4 2 1 7 6 4 2 1
1 2 7 8 9 1 2 7 8 9
9 7 6 2 1 9 7 6 2 1

View File

@@ -1,8 +1,12 @@
import re import re
from functools import reduce from functools import reduce
test_input_p1 = "xmul(2,4)%&mul[3,7]!@^do_not_mul(5,5)+mul(32,64]then(mul(11,8)mul(8,5))" test_input_p1 = (
test_input_p2 = "xmul(2,4)&mul[3,7]!^don't()_mul(5,5)+mul(32,64](mul(11,8)undo()?mul(8,5))" "xmul(2,4)%&mul[3,7]!@^do_not_mul(5,5)+mul(32,64]then(mul(11,8)mul(8,5))"
)
test_input_p2 = (
"xmul(2,4)&mul[3,7]!^don't()_mul(5,5)+mul(32,64](mul(11,8)undo()?mul(8,5))"
)
test_solution_p1 = 161 test_solution_p1 = 161
test_solution_p2 = 48 test_solution_p2 = 48

View File

@@ -1,7 +1,7 @@
from typing import Iterator from typing import Iterator
test_input_p1 = test_input_p2 = """ test_input = """
MMMSXXMASM MMMSXXMASM
MSAMXMSMSA MSAMXMSMSA
AMXSXMAAMM AMXSXMAAMM
@@ -39,8 +39,7 @@ def _scan_word_grid(
yield from (row for row in char_grid) yield from (row for row in char_grid)
yield from (tuple(col) for col in zip(*char_grid)) yield from (tuple(col) for col in zip(*char_grid))
yield from ( yield from (
_diagonal(char_grid, i) _diagonal(char_grid, i) for i in range(-len(char_grid) + 1, len(char_grid[0]))
for i in range(-len(char_grid) + 1, len(char_grid[0]))
) )
yield from ( yield from (
_diagonal(tuple(reversed(char_grid)), i) _diagonal(tuple(reversed(char_grid)), i)
@@ -57,28 +56,29 @@ def _diagonal(grid: tuple[tuple[str, ...], ...], offset=0) -> tuple[str, ...]:
def _count_xmas(line: tuple[str, ...]) -> int: def _count_xmas(line: tuple[str, ...]) -> int:
return sum( return sum(
1 1
for i for i in range(len(line) - 3)
in range(len(line) - 3)
if line[i : i + 4] in (("X", "M", "A", "S"), ("S", "A", "M", "X")) if line[i : i + 4] in (("X", "M", "A", "S"), ("S", "A", "M", "X"))
) )
ThreeSquare = tuple[
tuple[str, str, str], ThreeSquare = tuple[tuple[str, str, str], tuple[str, str, str], tuple[str, str, str]]
tuple[str, str, str],
tuple[str, str, str]
]
def _three_squares(word_grid: tuple[tuple[str, ...], ...]) -> Iterator[ThreeSquare]: def _three_squares(word_grid: tuple[tuple[str, ...], ...]) -> Iterator[ThreeSquare]:
yield from ( yield from (
(word_grid[i][j:j+3], word_grid[i+1][j:j+3], word_grid[i+2][j:j+3]) (
word_grid[i][j : j + 3],
word_grid[i + 1][j : j + 3],
word_grid[i + 2][j : j + 3],
)
for i in range(len(word_grid) - 2) for i in range(len(word_grid) - 2)
for j in range(len(word_grid[0]) - 2) for j in range(len(word_grid[0]) - 2)
) )
def _has_cross_mas(square: ThreeSquare): def _has_cross_mas(square: ThreeSquare):
diag_1 = (square[0][0], square[1][1], square[2][2]) diag_1 = (square[0][0], square[1][1], square[2][2])
diag_2 = (square[0][2], square[1][1], square[2][0]) diag_2 = (square[0][2], square[1][1], square[2][0])
return ( return (diag_1 == ("M", "A", "S") or diag_1 == ("S", "A", "M")) and (
(diag_1 == ("M", "A", "S") or diag_1 == ("S", "A", "M")) diag_2 == ("M", "A", "S") or diag_2 == ("S", "A", "M")
and (diag_2 == ("M", "A", "S") or diag_2 == ("S", "A", "M")) )
)

View File

@@ -1,7 +1,7 @@
from typing import Iterator from typing import Iterator
test_input_p1 = """ test_input = """
47|53 47|53
97|13 97|13
97|61 97|61
@@ -31,7 +31,6 @@ test_input_p1 = """
61,13,29 61,13,29
97,13,75,29,47 97,13,75,29,47
""".strip() """.strip()
test_input_p2 = test_input_p1
test_solution_p1 = 143 test_solution_p1 = 143
test_solution_p2 = 123 test_solution_p2 = 123

View File

@@ -31,13 +31,15 @@ def main():
try: try:
print("Testing part 1 solution...") print("Testing part 1 solution...")
test_input_p1 = getattr(mod, "test_input_p1", getattr(mod, "test_input", None))
start = time() start = time()
assert mod.solve_p1(mod.test_input_p1) == mod.test_solution_p1 assert mod.solve_p1(test_input_p1) == mod.test_solution_p1
print(f"Test passed in {time() - start:.3f} seconds") print(f"Test passed in {time() - start:.3f} seconds")
print("Testing part 2 solution...") print("Testing part 2 solution...")
test_input_p2 = getattr(mod, "test_input_p2", getattr(mod, "test_input", None))
start = time() start = time()
assert mod.solve_p2(mod.test_input_p2) == mod.test_solution_p2 assert mod.solve_p2(test_input_p2) == mod.test_solution_p2
print(f"Test passed in {time() - start:.3f} seconds") print(f"Test passed in {time() - start:.3f} seconds")
except AssertionError: except AssertionError: