97 lines
2.6 KiB
Python
97 lines
2.6 KiB
Python
import numpy as np
|
|
from pprint import pprint
|
|
import sys
|
|
|
|
from numpy.lib.function_base import delete
|
|
|
|
# Segment counts that identify a digit on their own:
# 1 lights 2 segments, 4 lights 4, 7 lights 3 and 8 lights all 7.
EASY_LENGTHS = np.array([2, 4, 3, 7])

# Element-wise ``len`` over an array-like of strings or sets.  ``otypes`` fixes
# the result dtype up front; without it np.vectorize raises ValueError when
# called on an empty input (it would need a first element to infer the type).
vec_len = np.vectorize(len, otypes=[int])
|
|
|
|
|
|
def part_one(lines):
    """Count output words whose length alone identifies the digit.

    Each line has the form ``"<patterns> | <output words>"``.  Only the words
    after the ``|`` are inspected; a word counts when its length is one of
    ``EASY_LENGTHS`` (the segment counts of 1, 4, 7 and 8).  Blank lines are
    skipped.

    :param lines: iterable of input lines (trailing newlines are fine)
    :return: total number of "easy" output words over all lines
    :raises ValueError: if a non-blank line does not contain exactly one ``|``
    """
    total = 0
    for line in lines:
        if not line.strip():
            continue  # tolerate blank / trailing lines in the input
        # Only the output section matters here; the patterns are ignored.
        _, output_str = line.split("|")
        total += sum(1 for word in output_str.split() if len(word) in EASY_LENGTHS)
    return total
|
|
|
|
|
|
def _single(candidates, what):
    """Return the only pattern in *candidates*; raise ValueError otherwise."""
    if len(candidates) != 1:
        raise ValueError(
            "expected exactly one pattern for digit %s, found %d"
            % (what, len(candidates))
        )
    return candidates[0]


def _deduce_wiring(patterns):
    """Map each of the ten scrambled patterns (frozensets) to its digit char.

    Deduction order:
    - 1, 4, 7 and 8 have unique segment counts (2, 4, 3 and 7).
    - Five segments: 3 is the only superset of 7; of the rest, 5 shares
      exactly three segments with 4; the remaining one is 2.  (3 must be
      removed first, because it also shares three segments with 4.)
    - Six segments: 9 is the only superset of 4; of the rest, 0 is a
      superset of 7; the remaining one is 6.

    :raises ValueError: if the patterns do not determine a unique wiring
    """
    by_length = {}
    for pattern in patterns:
        by_length.setdefault(len(pattern), []).append(pattern)

    one = _single(by_length.get(2, []), "1")
    four = _single(by_length.get(4, []), "4")
    seven = _single(by_length.get(3, []), "7")
    eight = _single(by_length.get(7, []), "8")

    fives = by_length.get(5, [])
    three = _single([p for p in fives if seven <= p], "3")
    five = _single([p for p in fives if p != three and len(p & four) == 3], "5")
    two = _single([p for p in fives if p not in (three, five)], "2")

    sixes = by_length.get(6, [])
    nine = _single([p for p in sixes if four <= p], "9")
    zero = _single([p for p in sixes if p != nine and seven <= p], "0")
    six = _single([p for p in sixes if p not in (nine, zero)], "6")

    ordered = (zero, one, two, three, four, five, six, seven, eight, nine)
    return {pattern: str(digit) for digit, pattern in enumerate(ordered)}


def part_two(lines):
    """Decode every display's output value and return the sum of them.

    Each line is ``"<ten patterns> | <output words>"``.  The patterns are used
    to deduce which scrambled set of segments stands for which digit; letter
    order within a word is irrelevant, so words are compared as frozensets.
    Blank lines are skipped.

    :param lines: iterable of input lines (trailing newlines are fine)
    :return: sum of the decoded output values
    :raises ValueError: if a non-blank line does not contain exactly one
        ``|``, or its patterns do not determine a unique wiring
    """
    total = 0
    for line in lines:
        if not line.strip():
            continue  # tolerate blank / trailing lines in the input
        patterns_str, output_str = line.split("|")
        decode = _deduce_wiring([frozenset(word) for word in patterns_str.split()])
        total += int("".join(decode[frozenset(word)] for word in output_str.split()))
    return total
|
|
|
|
|
|
if __name__ == "__main__":
    # Materialise stdin once so both solvers can iterate over the same lines.
    puzzle_input = list(sys.stdin)
    for label, solver in (("Part one:", part_one), ("Part two:", part_two)):
        print(label, solver(puzzle_input))
|