import sys
input = sys.stdin.readline
def findparent(x):
if parent[x]!=x:
parent[x] = findparent(parent[x])
return parent[x]
N,M = map(int,input().split())
graph = sorted([[*map(int,input().split())] for i in range(M)],key=lambda x:-x[2])
parent = [i for i in range(N+1)]; group = [1]*(N+1)
result = 0; connected = 0
for x,y,c in graph:
if findparent(x)!=findparent(y):
connected += group[parent[x]]*group[parent[y]]
group[parent[x]] += group[parent[y]]
parent[parent[y]] = parent[x]
result += c*connected
result %= 10**9
print(result)
union-find 응용문제
아이디어를 떠올리기 까다로운 문제였다. 사실 문제를 이해하는 거부터가 쉽지 않았다. 하지만 구현은 매우 간단하다.
풀이를 설명하면,
1. 그래프를 간선의 가중치를 기준으로 내림차순으로 정렬한다.
2. connected는 연결된 노드쌍의 수를 의미하고 group은 현재 group의 크기를 의미한다.
3. 가중치가 큰 간선부터 탐색한다. 이때 x와 y가 연결되지 않았으면 union 해주고, x의 group과 y의 group의 크기의 곱 만큼 connected에 더해준다.
4. 현재까지 연결된 노드쌍의 수*가중치를 결과값에 더해준다.
5. 결과값을 출력한다.
가중치가 작은 간선부터 제거하므로 역으로 가중치가 큰 간선부터 union해주면서 현재까지 연결된 쌍의 개수만큼 더해주는것이 포인트이다.
오늘의 교훈) union-find를 적재적소에 활용하자