import sys,math
input = sys.stdin.readline
def find(x):
if parent[x]!=x:
parent[x] = find(parent[x])
return parent[x]
N,M,Q = map(int,input().split())
graph = [[0,*map(int,input().split())] for i in range(M)]
for i in range(Q):
graph[int(input())-1][0] = Q-i
graph.sort()
parent = [i for i in range(N+1)]; group = [1]*(N+1)
result = 0
for i,x,y in graph:
if find(x)==find(y):
continue
if i:
result += group[parent[x]]*group[parent[y]]
group[parent[x]] += group[parent[y]]
parent[parent[y]] = parent[x]
print(result)
union-find 응용문제
비용 [백준 2463번] 비용 (Python3) (tistory.com) 문제와 풀이가 비슷하다. 하지만 문제가 훨씬 더 직관적이고 아이디어를 떠올리기는 난이도도 매우 쉽다.
풀이를 설명하면,
1. 간선 입력이 x,y로 들어오면 이를 [0,x,y]로 저장한다.
2. i 번째로 제거할 간선의 번호는 Q-i번으로 바꿔준다. (즉 [Q-i,x,y]로 저장)
3. 그래프를 정렬한다.
4. 그래프를 순서대로 union 해주고, 그룹의 노드 수를 저장한다.
5. 간선의 번호가 0인 경우에는 그냥 union한다.
6. 간선의 번호가 1 이상인 경우에는 같은 그룹이 아니라면 x와 y 그룹의 노드수의 곱을 결과에 더한다.
7. 결과값 출력
매우 간단한 문제였다. 나누는 것보다 합치는게 더 쉽다는걸 생각하면 역순으로 풀어야 한다는 것을 쉽게 떠올릴 수 있을 것이다.
오늘의 교훈) 다양한 문제를 풀자