처음에는 "뭐야 저번에 풀었던 별자리 만들기랑 똑같은 문제 아니야? 너무 쉬운데?" 라고 생각했다.
그래서 별자리 문제를 풀었던 방식대로 모든 노드 사이에 그래프를 생성하고 프림 알고리즘을 이용해서 풀어주려했다.
코드는 다음과 같다.
import sys
input = sys.stdin.readline
from heapq import heappush, heappop
N = int(input())
xlist,ylist,zlist = [0]*N,[0]*N,[0]*N
for i in range(N):
xlist[i],ylist[i],zlist[i] = map(int,input().split())
graph = [[0]*N for i in range(N)]
for i in range(N):
for j in range(i+1,N):
graph[i][j]=graph[j][i]=min(abs(xlist[i]-xlist[j]),abs(ylist[i]-ylist[j]),abs(zlist[i]-zlist[j]))
cost = 0
check = [0]*N
check[0] = 1
hq = []
for i in range(N):
heappush(hq,(graph[0][i],i))
for _ in range(N-1):
while hq:
c,x = heappop(hq)
if check[x]:
continue
check[x] = 1
cost += c
for i in range(N):
heappush(hq,(graph[x][i],i))
print(cost)
그러나 메모리 초과가 나왔다.
문제를 다시 확인하니 N이 무려 10만까지 갈 수 있다는 사실을 알게 되었다. 그럼 그래프의 개수는 N^2개가 되므로 당연히 메모리초과가 날 수밖에 없었다.
이 문제에서 중요한 포인트는 행성간에 cost가 별자리 만들기와는 다르게 좌표간의 거리가 아닌 x,y,z 중 짧은 거리라는 것이다.
이를 이용해 풀이한 방법은 다음과 같다.
import sys
input = sys.stdin.readline
from heapq import heappush, heappop
N = int(input())
xlist,ylist,zlist = [],[],[]
for num in range(N):
x,y,z = map(int,input().split())
xlist.append((x,num))
ylist.append((y,num))
zlist.append((z,num))
xlist.sort()
ylist.sort()
zlist.sort()
xyz = [xlist,ylist,zlist]
graph = [[] for i in range(N)]
for i in range(3):
for j in range(N-1):
a1,num1 = xyz[i][j]
a2,num2 = xyz[i][j+1]
graph[num1].append((num2,(a2-a1)))
graph[num2].append((num1,(a2-a1)))
cost = 0
check = [0]*N
check[0] = 1
hq = []
for next,c in graph[0]:
heappush(hq,(c,next))
for _ in range(N-1):
while hq:
c,now = heappop(hq)
if check[now]:
continue
check[now] = 1
cost += c
for next,c in graph[now]:
heappush(hq,(c,next))
print(cost)
모든 행성에 대해 x,y,z에 대한 리스트를 만들어준다. 그리고 x,y,z 리스트를 정렬한다.
그리고 모든 행성간에 그래프를 그리는게 아니라 x,y,z에 대해서 이웃한 행성 (최대 6개) 에 대한 그래프를 만들어준다.
그리고 똑같이 프림 알고리즘을 실행하면 된다.
이 아이디어를 떠올리기가 쉽지 않았던 문제였다.
오늘의 교훈) 데이터의 최대치를 잘 확인해서 시간, 공간 복잡도를 신경쓰자