Simplify day 3 part 2 solver

This commit is contained in:
Nettika
2023-12-16 17:40:14 -08:00
parent 54271e0a2c
commit 3da79d284e
2 changed files with 103 additions and 91 deletions

View File

@@ -1,45 +1,54 @@
"Day 3: Gear Ratios" "Day 3: Gear Ratios"
from __future__ import annotations from __future__ import annotations
from collections import UserDict
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce
import math
from typing import NamedTuple from typing import NamedTuple
class SchematicNumber(NamedTuple): class Symbol(NamedTuple):
number: str value: str
row: int row: int
col: int col: int
def extend_digit(self, digit: str) -> SchematicNumber:
return SchematicNumber(self.number + digit, self.row, self.col)
class Number(NamedTuple):
class SchematicSymbol(NamedTuple): value: str
symbol: str
row: int row: int
col: int col: int
anchor: Symbol | None = None
def extend_digit(self, digit: str) -> Number:
return self._replace(value=self.value + digit)
class Part(NamedTuple):
number: Number
symbol: Symbol
@dataclass @dataclass
class Schematic: class Schematic:
numbers: list[SchematicNumber] numbers: list[Number]
symbols: list[SchematicSymbol] symbols: list[Symbol]
@classmethod @classmethod
def parse(cls, input: str) -> Schematic: def parse(cls, input: str) -> Schematic:
row = 0 row = 0
col = 0 col = 0
current_number: SchematicNumber | None = None current_number: Number | None = None
numbers: list[SchematicNumber] = [] numbers: list[Number] = []
symbols: list[SchematicSymbol] = [] symbols: list[Symbol] = []
for char in input: for char in input:
match char: match char:
# Digit # Digit
case n if n in "0123456789": case n if n in "0123456789":
if not current_number: if not current_number:
current_number = SchematicNumber("", row, col) current_number = Number("", row, col)
current_number = current_number.extend_digit(char) current_number = current_number.extend_digit(char)
col += 1 col += 1
@@ -63,7 +72,7 @@ class Schematic:
if current_number: if current_number:
numbers.append(current_number) numbers.append(current_number)
current_number = None current_number = None
symbols.append(SchematicSymbol(char, row, col)) symbols.append(Symbol(char, row, col))
col += 1 col += 1
# Finalize a number at the end of the schematic # Finalize a number at the end of the schematic
@@ -72,40 +81,42 @@ class Schematic:
return cls(numbers, symbols) return cls(numbers, symbols)
def part_numbers(self) -> list[tuple[SchematicNumber, SchematicSymbol]]: def parts(self) -> set[Part]:
results = [] return {
Part(number, symbol)
for number in self.numbers
for symbol in self.symbols
if (
# Symbol within 1 row
(number.row - 1 <= symbol.row <= number.row + 1)
and
# Symbol within 1 column
(number.col - 1 <= symbol.col <= number.col + len(number.value))
)
}
for number in self.numbers: def part_groups(self) -> dict[Symbol, set[Part]]:
for symbol in self.symbols: groups: dict[Symbol, set[Part]] = {}
if ( for part in self.parts():
# Symbol within 1 row if part.symbol not in groups:
(number.row - 1 <= symbol.row <= number.row + 1) groups[part.symbol] = set()
and groups[part.symbol].add(part)
# Symbol within 1 column return groups
(number.col - 1 <= symbol.col <= number.col + len(number.number))
):
results.append((number, symbol))
break
return results
def solve_part_1(input: str) -> int: def solve_part_1(input: str) -> int:
schematic = Schematic.parse(input) schematic = Schematic.parse(input)
return sum(int(part_number.number) for part_number, _ in schematic.part_numbers()) return sum(int(part.number.value) for part in schematic.parts())
def solve_part_2(input: str) -> int: def solve_part_2(input: str) -> int:
schematic = Schematic.parse(input) schematic = Schematic.parse(input)
part_groups: dict[SchematicSymbol, list[SchematicNumber]] = {} return sum(
for part_number, part_symbol in schematic.part_numbers(): reduce(
if part_symbol.symbol != "*": lambda x, y: x * y,
continue (int(part.number.value) for part in parts),
if part_symbol not in part_groups: )
part_groups[part_symbol] = [] for symbol, parts in schematic.part_groups().items()
part_groups[part_symbol].append(part_number) if symbol.value == "*"
gears = [ if len(parts) == 2
part_numbers for part_numbers in part_groups.values() if len(part_numbers) == 2 )
]
return sum(int(gear_1.number) * int(gear_2.number) for gear_1, gear_2 in gears)

View File

@@ -1,7 +1,8 @@
from advent_of_code.gears import ( from advent_of_code.gears import (
Part,
Schematic, Schematic,
SchematicNumber, Number,
SchematicSymbol, Symbol,
solve_part_1, solve_part_1,
solve_part_2, solve_part_2,
) )
@@ -21,73 +22,73 @@ mock_input = """
def test_schematic_number_extend_digit(): def test_schematic_number_extend_digit():
assert SchematicNumber("", 1, 2).extend_digit("3") == SchematicNumber("3", 1, 2) assert Number("", 1, 2).extend_digit("3") == Number("3", 1, 2)
assert ( assert (Number("5", 3, 5).extend_digit("0").extend_digit("4")) == Number(
SchematicNumber("5", 3, 5).extend_digit("0").extend_digit("4") "504", 3, 5
) == SchematicNumber("504", 3, 5) )
def test_parse_schematic(): def test_parse_schematic():
assert Schematic.parse(mock_input) == Schematic( assert Schematic.parse(mock_input) == Schematic(
[ [
SchematicNumber("467", 0, 0), Number("467", 0, 0),
SchematicNumber("114", 0, 5), Number("114", 0, 5),
SchematicNumber("35", 2, 2), Number("35", 2, 2),
SchematicNumber("633", 2, 6), Number("633", 2, 6),
SchematicNumber("617", 4, 0), Number("617", 4, 0),
SchematicNumber("58", 5, 8), Number("58", 5, 8),
SchematicNumber("592", 6, 2), Number("592", 6, 2),
SchematicNumber("755", 7, 6), Number("755", 7, 6),
SchematicNumber("664", 9, 1), Number("664", 9, 1),
SchematicNumber("598", 9, 5), Number("598", 9, 5),
SchematicNumber("3", 9, 9), Number("3", 9, 9),
], ],
[ [
SchematicSymbol("*", 1, 3), Symbol("*", 1, 3),
SchematicSymbol("#", 3, 6), Symbol("#", 3, 6),
SchematicSymbol("*", 4, 3), Symbol("*", 4, 3),
SchematicSymbol("+", 5, 5), Symbol("+", 5, 5),
SchematicSymbol("$", 8, 3), Symbol("$", 8, 3),
SchematicSymbol("*", 8, 5), Symbol("*", 8, 5),
], ],
) )
def test_schematic_part_numbers(): def test_schematic_part_numbers():
assert Schematic.parse(mock_input).part_numbers() == [ assert Schematic.parse(mock_input).parts() == {
( Part(
SchematicNumber("467", 0, 0), Number("467", 0, 0),
SchematicSymbol("*", 1, 3), Symbol("*", 1, 3),
), ),
( Part(
SchematicNumber("35", 2, 2), Number("35", 2, 2),
SchematicSymbol("*", 1, 3), Symbol("*", 1, 3),
), ),
( Part(
SchematicNumber("633", 2, 6), Number("633", 2, 6),
SchematicSymbol("#", 3, 6), Symbol("#", 3, 6),
), ),
( Part(
SchematicNumber("617", 4, 0), Number("617", 4, 0),
SchematicSymbol("*", 4, 3), Symbol("*", 4, 3),
), ),
( Part(
SchematicNumber("592", 6, 2), Number("592", 6, 2),
SchematicSymbol("+", 5, 5), Symbol("+", 5, 5),
), ),
( Part(
SchematicNumber("755", 7, 6), Number("755", 7, 6),
SchematicSymbol("*", 8, 5), Symbol("*", 8, 5),
), ),
( Part(
SchematicNumber("664", 9, 1), Number("664", 9, 1),
SchematicSymbol("$", 8, 3), Symbol("$", 8, 3),
), ),
( Part(
SchematicNumber("598", 9, 5), Number("598", 9, 5),
SchematicSymbol("*", 8, 5), Symbol("*", 8, 5),
), ),
] }
def test_solve_part_1(): def test_solve_part_1():