From f78fa5d772da2ba6440d7ab62a311537084d5c22 Mon Sep 17 00:00:00 2001 From: Marijn Doeve Date: Thu, 16 Dec 2021 19:52:17 +0100 Subject: [PATCH] Optimize 16 --- 16/16.py | 42 ++++++++++++++++-------------------------- 1 file changed, 16 insertions(+), 26 deletions(-) diff --git a/16/16.py b/16/16.py index 4fe51eb..a0e3b2d 100644 --- a/16/16.py +++ b/16/16.py @@ -1,8 +1,11 @@ import sys from math import prod +from typing import Callable + +operator_func = Callable[[list[int]], int] LITERAL = 4 -OPERATOR = { +OPERATOR: dict[int, operator_func] = { 0: sum, 1: prod, 2: min, @@ -15,12 +18,12 @@ OPERATOR = { total_version = 0 -def read_bits(bits, start, n) -> tuple[str, int]: +def read_bits(bits: str, start: int, n: int) -> tuple[str, int]: end = start + n return bits[start:end], end -def read_bits_int(bits, start, n) -> tuple[int, int]: +def read_bits_int(bits: str, start: int, n: int) -> tuple[int, int]: number_bits, pos = read_bits(bits, start, n) return int(number_bits, base=2), pos @@ -33,33 +36,27 @@ def hex_string_to_bit_string(hex_string: str) -> str: return "".join([hex_to_bits(x) for x in list(hex_string)]) -def parse_packet(bits) -> int: +def parse_packet(bits: str) -> tuple[int, int]: pos = 0 - packet_version, _ = read_bits_int(bits, pos, 3) + packet_version, pos = read_bits_int(bits, pos, 3) global total_version total_version += packet_version - type_ID, _ = read_bits_int(bits[3:], pos, 3) + type_ID, pos = read_bits_int(bits, pos, 3) if type_ID == LITERAL: - # print("Found: literal") value, next_pos = parse_literal(bits[pos:]) - # print("literal:", value) else: - # print("found: operator") - value, next_pos = parse_operator(bits[pos:]) + value, next_pos = parse_operator(bits[pos:], type_ID) pos += next_pos return value, pos -def parse_literal(bits) -> tuple[int, int]: +def parse_literal(bits: str) -> tuple[int, int]: pos = 0 - packet_version, pos = read_bits_int(bits, pos, 3) - - type_ID, pos = read_bits_int(bits, pos, 3) number = "" @@ -76,33 +73,26 @@ def parse_literal(bits) -> tuple[int, int]: return int(number, base=2), pos -def parse_operator(bits) -> int: - pos = 0 - packet_version, pos = read_bits_int(bits, pos, 3) - type_ID, pos = read_bits_int(bits, pos, 3) +def parse_operator(bits: str, type_ID: int) -> tuple[int, int]: + length_type_id = bits[0] - length_type_id = bits[6] - - pos = 7 + pos = 1 parts = [] if length_type_id == "0": total_length, pos = read_bits_int(bits, pos, 15) - # print("total length:", total_length) read = 0 while read != total_length: value, just_read = parse_packet(bits[pos:]) read += just_read - # print(f"Length was: {just_read}, total: {read} of {total_length}") pos += just_read parts.append(value) else: number_of_subs, pos = read_bits_int(bits, pos, 11) - # print("number of subs:", number_of_subs) - for i in range(number_of_subs): + for _ in range(number_of_subs): value, next_pos = parse_packet(bits[pos:]) parts.append(value) pos += next_pos @@ -116,7 +106,7 @@ if __name__ == "__main__": hex_string = sys.stdin.readline().strip() bits = hex_string_to_bit_string(hex_string) - value, _ = parse_operator(bits) + value, _ = parse_packet(bits) print("part one:", total_version) print("part two:", value)