From 3da79d284e987a5b309103752c4ca3d19b7a180f Mon Sep 17 00:00:00 2001 From: Nettika Date: Sat, 16 Dec 2023 17:40:14 -0800 Subject: [PATCH] Simplify day 3 part 2 solver --- advent_of_code/gears.py | 95 ++++++++++++++++++++++----------------- tests/gears_test.py | 99 +++++++++++++++++++++-------------------- 2 files changed, 103 insertions(+), 91 deletions(-) diff --git a/advent_of_code/gears.py b/advent_of_code/gears.py index 97a316d..58f8550 100644 --- a/advent_of_code/gears.py +++ b/advent_of_code/gears.py @@ -1,45 +1,54 @@ "Day 3: Gear Ratios" from __future__ import annotations +from collections import UserDict from dataclasses import dataclass +from functools import reduce +import math from typing import NamedTuple -class SchematicNumber(NamedTuple): - number: str +class Symbol(NamedTuple): + value: str row: int col: int - def extend_digit(self, digit: str) -> SchematicNumber: - return SchematicNumber(self.number + digit, self.row, self.col) - -class SchematicSymbol(NamedTuple): - symbol: str +class Number(NamedTuple): + value: str row: int col: int + anchor: Symbol | None = None + + def extend_digit(self, digit: str) -> Number: + return self._replace(value=self.value + digit) + + +class Part(NamedTuple): + number: Number + symbol: Symbol @dataclass class Schematic: - numbers: list[SchematicNumber] - symbols: list[SchematicSymbol] + numbers: list[Number] + symbols: list[Symbol] @classmethod def parse(cls, input: str) -> Schematic: row = 0 col = 0 - current_number: SchematicNumber | None = None - numbers: list[SchematicNumber] = [] - symbols: list[SchematicSymbol] = [] + current_number: Number | None = None + numbers: list[Number] = [] + symbols: list[Symbol] = [] for char in input: match char: # Digit case n if n in "0123456789": if not current_number: - current_number = SchematicNumber("", row, col) + current_number = Number("", row, col) current_number = current_number.extend_digit(char) col += 1 @@ -63,7 +72,7 @@ class Schematic: if current_number: numbers.append(current_number) current_number = None - symbols.append(SchematicSymbol(char, row, col)) + symbols.append(Symbol(char, row, col)) col += 1 # Finalize a number at the end of the schematic @@ -72,40 +81,42 @@ class Schematic: return cls(numbers, symbols) - def part_numbers(self) -> list[tuple[SchematicNumber, SchematicSymbol]]: - results = [] + def parts(self) -> set[Part]: + return { + Part(number, symbol) + for number in self.numbers + for symbol in self.symbols + if ( + # Symbol within 1 row + (number.row - 1 <= symbol.row <= number.row + 1) + and + # Symbol within 1 column + (number.col - 1 <= symbol.col <= number.col + len(number.value)) + ) + } - for number in self.numbers: - for symbol in self.symbols: - if ( - # Symbol within 1 row - (number.row - 1 <= symbol.row <= number.row + 1) - and - # Symbol within 1 column - (number.col - 1 <= symbol.col <= number.col + len(number.number)) - ): - results.append((number, symbol)) - break - - return results + def part_groups(self) -> dict[Symbol, set[Part]]: + groups: dict[Symbol, set[Part]] = {} + for part in self.parts(): + if part.symbol not in groups: + groups[part.symbol] = set() + groups[part.symbol].add(part) + return groups def solve_part_1(input: str) -> int: schematic = Schematic.parse(input) - return sum(int(part_number.number) for part_number, _ in schematic.part_numbers()) + return sum(int(part.number.value) for part in schematic.parts()) def solve_part_2(input: str) -> int: schematic = Schematic.parse(input) - part_groups: dict[SchematicSymbol, list[SchematicNumber]] = {} - for part_number, part_symbol in schematic.part_numbers(): - if part_symbol.symbol != "*": - continue - if part_symbol not in part_groups: - part_groups[part_symbol] = [] - part_groups[part_symbol].append(part_number) - gears = [ - part_numbers for part_numbers in part_groups.values() if len(part_numbers) == 2 - ] - - return sum(int(gear_1.number) * int(gear_2.number) for gear_1, gear_2 in gears) + return sum( + reduce( + lambda x, y: x * y, + (int(part.number.value) for part in parts), + ) + for symbol, parts in schematic.part_groups().items() + if symbol.value == "*" + if len(parts) == 2 + ) diff --git a/tests/gears_test.py b/tests/gears_test.py index d234b8d..c32cca2 100644 --- a/tests/gears_test.py +++ b/tests/gears_test.py @@ -1,7 +1,8 @@ from advent_of_code.gears import ( + Part, Schematic, - SchematicNumber, - SchematicSymbol, + Number, + Symbol, solve_part_1, solve_part_2, ) @@ -21,73 +22,73 @@ mock_input = """ def test_schematic_number_extend_digit(): - assert SchematicNumber("", 1, 2).extend_digit("3") == SchematicNumber("3", 1, 2) - assert ( - SchematicNumber("5", 3, 5).extend_digit("0").extend_digit("4") - ) == SchematicNumber("504", 3, 5) + assert Number("", 1, 2).extend_digit("3") == Number("3", 1, 2) + assert (Number("5", 3, 5).extend_digit("0").extend_digit("4")) == Number( + "504", 3, 5 + ) def test_parse_schematic(): assert Schematic.parse(mock_input) == Schematic( [ - SchematicNumber("467", 0, 0), - SchematicNumber("114", 0, 5), - SchematicNumber("35", 2, 2), - SchematicNumber("633", 2, 6), - SchematicNumber("617", 4, 0), - SchematicNumber("58", 5, 8), - SchematicNumber("592", 6, 2), - SchematicNumber("755", 7, 6), - SchematicNumber("664", 9, 1), - SchematicNumber("598", 9, 5), - SchematicNumber("3", 9, 9), + Number("467", 0, 0), + Number("114", 0, 5), + Number("35", 2, 2), + Number("633", 2, 6), + Number("617", 4, 0), + Number("58", 5, 8), + Number("592", 6, 2), + Number("755", 7, 6), + Number("664", 9, 1), + Number("598", 9, 5), + Number("3", 9, 9), ], [ - SchematicSymbol("*", 1, 3), - SchematicSymbol("#", 3, 6), - SchematicSymbol("*", 4, 3), - SchematicSymbol("+", 5, 5), - SchematicSymbol("$", 8, 3), - SchematicSymbol("*", 8, 5), + Symbol("*", 1, 3), + Symbol("#", 3, 6), + Symbol("*", 4, 3), + Symbol("+", 5, 5), + Symbol("$", 8, 3), + Symbol("*", 8, 5), ], ) def test_schematic_part_numbers(): - assert Schematic.parse(mock_input).part_numbers() == [ - ( - SchematicNumber("467", 0, 0), - SchematicSymbol("*", 1, 3), + assert Schematic.parse(mock_input).parts() == { + Part( + Number("467", 0, 0), + Symbol("*", 1, 3), ), - ( - SchematicNumber("35", 2, 2), - SchematicSymbol("*", 1, 3), + Part( + Number("35", 2, 2), + Symbol("*", 1, 3), ), - ( - SchematicNumber("633", 2, 6), - SchematicSymbol("#", 3, 6), + Part( + Number("633", 2, 6), + Symbol("#", 3, 6), ), - ( - SchematicNumber("617", 4, 0), - SchematicSymbol("*", 4, 3), + Part( + Number("617", 4, 0), + Symbol("*", 4, 3), ), - ( - SchematicNumber("592", 6, 2), - SchematicSymbol("+", 5, 5), + Part( + Number("592", 6, 2), + Symbol("+", 5, 5), ), - ( - SchematicNumber("755", 7, 6), - SchematicSymbol("*", 8, 5), + Part( + Number("755", 7, 6), + Symbol("*", 8, 5), ), - ( - SchematicNumber("664", 9, 1), - SchematicSymbol("$", 8, 3), + Part( + Number("664", 9, 1), + Symbol("$", 8, 3), ), - ( - SchematicNumber("598", 9, 5), - SchematicSymbol("*", 8, 5), + Part( + Number("598", 9, 5), + Symbol("*", 8, 5), ), - ] + } def test_solve_part_1():