diff --git a/day_05/program.py b/day_05/program.py index 7000ac3..a2e87fa 100755 --- a/day_05/program.py +++ b/day_05/program.py @@ -26,20 +26,37 @@ def check_validity(rules: dict, update: list) -> bool: for position, page_number in enumerate(update): if page_number not in rules: continue - # if a number that is supposed to come earlier than our current page_number - # is included in the rest of the update, it breaks the rules for earlier_number in rules[page_number]: if earlier_number in update[position:]: return False return True -def get_middle_page_sums(rules: dict, updates: list) -> int: - middle_page_sum = 0 +def sort_update(rules: dict, update: list) -> None: + len_update = len(update) + for i in range(len_update - 1): + for j in range(len_update - i - 1): + if update[j] not in rules: + continue + if update[j + 1] in rules[update[j]]: + update[j], update[j + 1] = update[j + 1], update[j] + + +def get_middle_page(update: list) -> int: + middle_pos = len(update) // 2 + return int(update[middle_pos]) + + +def get_middle_page_sums(rules: dict, updates: list) -> tuple: + valid_middle_page_sum = 0 + corrected_middle_page_sum = 0 for update in updates: if check_validity(rules, update): - middle_page_sum += int(update[int(len(update) / 2)]) - return middle_page_sum + valid_middle_page_sum += get_middle_page(update) + else: + sort_update(rules, update) + corrected_middle_page_sum += get_middle_page(update) + return valid_middle_page_sum, corrected_middle_page_sum def main(): @@ -47,7 +64,9 @@ def main(): lines = get_lines("input.txt") rules, updates = get_rules_and_updates(lines) - print("Part 1: The sum of valid middle pages is:", get_middle_page_sums(rules, updates)) + valid_middle_page_sum, invalid_middle_page_sum = get_middle_page_sums(rules, updates) + print("Part 1: The sum of middle pages in valid updates is:", valid_middle_page_sum) + print("Part 2: The sum of middle pages in corrected updates is:", invalid_middle_page_sum) if __name__ == '__main__':