99 lines
2.1 KiB
Python
99 lines
2.1 KiB
Python
import sys
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
from collections import defaultdict
|
|
from typing import Optional
|
|
|
|
|
|
def indeces(y, x):
    """Return the four orthogonal neighbours of (y, x) in N, E, S, W order.

    The coordinates are not bounds-checked: some may be negative or past the
    grid edge, and callers are responsible for discarding those.
    """
    offsets = ((-1, 0), (0, 1), (1, 0), (0, -1))  # N, E, S, W
    return [(y + dy, x + dx) for dy, dx in offsets]
|
|
|
|
|
|
def is_low_point(grid: np.ndarray, y: int, x: int) -> bool:
    """Return True if grid[y, x] is strictly lower than every in-bounds neighbour.

    Neighbours come from ``indeces`` (N, E, S, W). Neighbours that fall
    outside the 2-D grid are ignored, so edge and corner cells are compared
    only against the neighbours they actually have. An equal-height
    neighbour disqualifies the cell.
    """
    height, width = grid.shape
    current = grid[y, x]
    # Explicit bounds check: NumPy would silently wrap negative indices to
    # the opposite edge instead of raising, so they must be filtered out.
    return all(
        grid[ny, nx] > current
        for ny, nx in indeces(y, x)
        if 0 <= ny < height and 0 <= nx < width
    )
|
|
|
|
|
|
def lowest_neighbour(grid: np.ndarray, y: int, x: int) -> Optional[tuple[int, int]]:
    """Return the coordinates of the lowest neighbour strictly below grid[y, x].

    Only in-bounds N/E/S/W neighbours (from ``indeces``) are considered.
    When several neighbours tie for lowest, the first one in N, E, S, W order
    wins. Returns None when no neighbour is strictly lower than (y, x).
    """
    height, width = grid.shape
    lowest = grid[y, x]
    point = None
    for ny, nx in indeces(y, x):
        # Explicit bounds check: negative indices would silently wrap in NumPy.
        if not (0 <= ny < height and 0 <= nx < width):
            continue
        if (candidate := grid[ny, nx]) < lowest:
            lowest = candidate
            point = (ny, nx)
    return point
|
|
|
|
|
|
def heatmap(grid, save_path: Optional[str] = None):
    """Display ``grid`` as a greyscale image with one pixel per cell.

    If ``save_path`` is given, the figure is also written to that path with
    ``plt.savefig`` before being shown. By default nothing is saved, which
    matches the original behaviour.
    """
    plt.imshow(grid, cmap="Greys", interpolation="nearest")
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()
|
|
|
|
|
|
def part_one(grid: np.ndarray):
    """Sum the risk levels (height + 1) of every low point in ``grid``."""
    is_low = np.zeros(grid.shape, dtype=bool)
    for y, x in np.ndindex(*grid.shape):
        is_low[y, x] = is_low_point(grid, y, x)

    risk_levels = grid[is_low] + 1
    return sum(risk_levels)
|
|
|
|
|
|
def part_two(grid: np.ndarray):
    """Return the product of the sizes of the (up to) three largest basins.

    Each cell whose height is not 9 follows ``lowest_neighbour`` step by step
    until it reaches a cell with no strictly lower neighbour. That terminal
    cell identifies the basin the starting cell is counted in. Cells of
    height 9 are not counted in any basin.
    """
    # Where each cell's smoke flows next (None when nothing is lower).
    # A dict keyed by (y, x) replaces the original object-dtype ndarray,
    # which is what ``dtype=tuple`` silently became.
    flows_to = {
        point: lowest_neighbour(grid, *point) for point in np.ndindex(grid.shape)
    }

    basin_sizes = defaultdict(int)
    for start, height in np.ndenumerate(grid):
        if height == 9:
            continue
        point = start
        # `step` rather than `next`, which would shadow the builtin.
        while (step := flows_to[point]) is not None:
            point = step
        basin_sizes[point] += 1

    return np.prod(sorted(basin_sizes.values(), reverse=True)[:3])
|
|
|
|
|
|
def parse_grid(lines) -> np.ndarray:
    """Parse rows of single-digit heights into a 2-D integer array.

    Each line is stripped of surrounding whitespace and every character is
    converted with ``int``. Blank lines are skipped so that a stray empty
    line does not produce an empty (ragged) row.
    """
    return np.array(
        [[int(ch) for ch in line.strip()] for line in lines if line.strip()]
    )


if __name__ == "__main__":
    # `sys` is already imported at module top; no local re-import needed.
    grid = parse_grid(sys.stdin)

    print("Part One:", part_one(grid))
    print("Part Two:", part_two(grid))
|