This commit is contained in:
2021-12-09 20:28:28 +01:00
parent 0298d84eca
commit 31eb702de7

76
09/9.py
View File

@@ -1,25 +1,27 @@
from pprint import pprint
import numpy as np
import sys
import matplotlib.pyplot as plt
import numpy as np
np.set_printoptions(
threshold=sys.maxsize,
# linewidth=1000,
)
from collections import defaultdict
from typing import Optional
def is_low_point(grid: np.ndarray, y: int, x: int) -> bool:
current = grid[y, x]
indexies = [
def indeces(y, x):
return [
(y - 1, x), # N
(y, x + 1), # E
(y + 1, x), # S
(y, x - 1), # W
]
for index in indexies:
def is_low_point(grid: np.ndarray, y: int, x: int) -> bool:
current = grid[y, x]
for index in indeces(y, x):
try:
if index[0] < 0 or index[1] < 0:
raise IndexError
if grid[index] <= current:
return False
except IndexError:
@@ -28,19 +30,60 @@ def is_low_point(grid: np.ndarray, y: int, x: int) -> bool:
return True
def lowest_neighbour(grid: np.ndarray, y: int, x: int) -> Optional[tuple[int, int]]:
lowest = grid[y, x]
point = None
for index in indeces(y, x):
try:
if index[0] < 0 or index[1] < 0:
raise IndexError
if (new := grid[index]) < lowest:
lowest = new
point = index
except IndexError:
pass
return point
def heatmap(grid):
plt.imshow(grid, cmap="Greys", interpolation="nearest")
# plt.savefig("heat.png")
plt.show()
def part_one(grid: np.ndarray):
low_points = np.empty_like(grid, dtype=bool)
for point in np.ndindex(grid.shape):
low_points[point] = is_low_point(grid, *point)
# print(str(grid).replace(" ", ""))
# print(str(low_points).replace(" True", "o").replace("False", ".").replace(" ", ""))
# print(grid[low_points] + 1)
return sum(grid[low_points] + 1)
def part_two(grid: np.ndarray):
smoke_direction = np.empty_like(grid, dtype=tuple)
for point in np.ndindex(grid.shape):
smoke_direction[point] = lowest_neighbour(grid, *point)
basins = defaultdict(int)
for starting_point, height in np.ndenumerate(grid):
if height == 9:
continue
point = starting_point
while (next := smoke_direction[point]) is not None:
point = next
basins[point] += 1
return np.prod(sorted(basins.values(), reverse=True)[:3])
if __name__ == "__main__":
import sys
@@ -49,4 +92,7 @@ if __name__ == "__main__":
for line in sys.stdin.readlines():
grid_lists.append([int(x) for x in list(line.strip())])
grid = np.array(grid_lists)
print("Part One:", part_one(grid))
print("Part Two:", part_two(grid))