from typing import Iterator

test_input = """
47|53
97|13
97|61
97|47
75|29
61|13
75|53
29|13
97|29
53|29
61|53
97|53
61|29
47|13
75|47
97|75
47|61
75|61
47|29
75|13
53|13

75,47,61,53,29
97,61,53,29,13
75,29,13
75,97,47,61,53
61,13,29
97,13,75,29,47
""".strip()

test_solution_p1 = 143
test_solution_p2 = 123


def solve_p1(puzzle_input: str) -> int:
    """Sum the middle page of every update that already satisfies the rules.

    Args:
        puzzle_input: The ordering rules (``X|Y`` lines), a blank line, then
            the updates (comma-separated page numbers, one update per line).

    Returns:
        The sum of ``update[len(update) // 2]`` over correctly ordered updates.
    """
    rules, updates = _parse_rules_and_updates(puzzle_input)
    return sum(
        update[len(update) // 2]
        for update in updates
        if _is_correctly_ordered(rules, update)
    )


def solve_p2(puzzle_input: str) -> int:
    """Sum the middle page of every incorrectly ordered update after fixing it.

    Args:
        puzzle_input: Same format as for :func:`solve_p1`.

    Returns:
        The sum of the middle pages of the reordered (previously invalid)
        updates.

    Raises:
        ValueError: If the rules restricted to some update are cyclic, so no
            valid ordering exists.
    """
    rules, updates = _parse_rules_and_updates(puzzle_input)
    corrected_updates = (
        _fix_update(rules, update)
        for update in updates
        if not _is_correctly_ordered(rules, update)
    )
    return sum(update[len(update) // 2] for update in corrected_updates)


# Maps a page to the set of pages that must appear before it (from "X|Y" rules,
# stored as rules[Y] ⊇ {X}).
OrderingRules = dict[int, set[int]]
Update = tuple[int, ...]


def _parse_rules_and_updates(
    puzzle_input: str,
) -> tuple[OrderingRules, Iterator[Update]]:
    """Split the puzzle input into ordering rules and a lazy stream of updates.

    Windows (``\\r\\n``) line endings are accepted, and blank lines inside
    either section are ignored.

    Returns:
        A ``(rules, updates)`` pair. ``rules[b]`` is the set of pages ``a``
        for which a rule ``a|b`` exists. ``updates`` is a generator of
        page tuples, so it can be consumed only once.
    """
    # Normalise line endings so the blank-line separator is found in CRLF input.
    normalized = puzzle_input.replace("\r\n", "\n")
    rules_section, _, updates_section = normalized.partition("\n\n")

    rules: OrderingRules = {}
    for line in rules_section.splitlines():
        if not line.strip():
            continue  # tolerate stray blank lines instead of crashing on int("")
        before, _, after = line.partition("|")
        rules.setdefault(int(after), set()).add(int(before))

    updates = (
        tuple(map(int, line.split(",")))
        for line in updates_section.splitlines()
        if line.strip()
    )
    return rules, updates


def _is_correctly_ordered(rules: OrderingRules, update: Update) -> bool:
    """Return True if no page in *update* precedes a page that must come first.

    Only rules whose both pages occur in *update* are considered.
    """
    contains = set(update)
    seen: set[int] = set()
    for page in update:
        for predecessor in rules.get(page, ()):
            # A required predecessor that is present but not yet seen must
            # come *after* this page, which violates the rule.
            if predecessor in contains and predecessor not in seen:
                return False
        seen.add(page)
    return True


def _fix_update(rules: OrderingRules, update: Update) -> Update:
    """Reorder *update* so that it satisfies every applicable rule.

    Performs a depth-first topological sort restricted to the pages of
    *update*. Before a page is emitted, all of its required predecessors that
    also occur in the update are emitted first. A page is emitted only once.

    Raises:
        ValueError: If the rules restricted to this update form a cycle (the
            original recursion would never terminate in that case).
    """
    contains = set(update)
    placed: set[int] = set()
    in_progress: set[int] = set()
    fixed_update: list[int] = []

    def _place(page: int) -> None:
        """Emit *page* after recursively emitting its in-update predecessors."""
        if page in placed:
            return
        if page in in_progress:
            # Re-entering a page still on the DFS stack means a cycle.
            raise ValueError(
                f"ordering rules contain a cycle involving page {page}"
            )
        in_progress.add(page)
        for predecessor in rules.get(page, ()):
            if predecessor in contains:
                _place(predecessor)
        in_progress.discard(page)
        placed.add(page)
        fixed_update.append(page)

    for page in update:
        _place(page)
    return tuple(fixed_update)