처음에는 union-find 문제라고 생각했다.
union-find를 이용해서 cycle과 cycle이 아닌 것을 찾고, cycle은 cycle의 크기를 기록, cycle이 아니면 cycle까지의 거리를 저장하고, 주어지는 n에 대해서 cycle의 크기로 나눈 나머지만큼 함수를 실행하면 된다고 생각했다.
코드는 다음과 같다.
import sys
input = sys.stdin.readline
def findparent(x):
if parent[x] == x:
return x
parent[x] = findparent(parent[x])
return parent[x]
def DFS(x):
if visited[x]:
cyclecount(x)
return
visited[x] = 1
DFS(f[x])
if degree[x] != -1:
return
findparent(x)
degree[x] = degree[f[x]]+1
def cyclecount(x):
cnt = 1
x1 = f[x]
while x1 != x:
x1 = f[x1]
cnt += 1
while not cycle[x]:
cycle[x] = cnt
degree[x] = 0
parent[x] = x
x = f[x]
N = int(input())
data = [*map(int,input().split())]
f = {}
for i in range(1,N+1):
f[i] = data[i-1]
parent = {i:f[i] for i in range(1,N+1)}
degree = {i:-1 for i in range(1,N+1)}
cycle = {i:0 for i in range(1,N+1)}
visited = [0]*(N+1)
for i in range(1,N+1):
if visited[i]:
continue
DFS(i)
Q = int(input())
for i in range(Q):
n,x = map(int,input().split())
if n >= degree[x]:
n -= degree[x]
x = parent[x]
for _ in range(n%cycle[x]):
x = f[x]
else:
for _ in range(n):
x = f[x]
print(x)
DFS로 탐사하면서 이미 방문한 곳을 재방문시 cycle이라고 판단하고 cycle함수로 cycle의 크기를 기록, 나머지 cycle 밖의 노드에 대해서는 cycle까지의 거리와 가장 가까운 cycle의 node를 parent로 기록하는 코드이다.
그러나 이 코드는 시간초과가 발생하였다.
아무리 n을 cycle의 크기의 나머지로 크기를 줄여줬다지만 만약 cycle의 크기가 최대 20만으로 매우 크다면 나머지 또한 매우 커져 결국 시간복잡도가 O(N^2)이 되는 것이 문제로 보였다.
이 문제는 union-find 문제가 아니었다.
sparse table이라는 새로운 알고리즘을 소개하는 문제였던 것이다.
import sys
input = sys.stdin.readline
N = int(input())
f = [[0]*(N+1) for i in range(20)]
data = [*map(int,input().split())]
for i in range(N):
f[0][i+1] = data[i]
for i in range(1,20):
for n in range(1,N+1):
f[i][n] = f[i-1][f[i-1][n]]
Q = int(input())
for _ in range(Q):
n,x = map(int,input().split())
for i in range(20):
if n&(1<<i):
x = f[i][x]
print(x)
sparce-table에 대한 설명은 https://namnamseo.tistory.com/entry/Sparse-Table 여기를 참고하였다.
위 블로그의 설명을 토대로 코드를 짜봤더니 코드도 훨씬 간결하고 답도 빠른 시간내에 낼 수 있었다.
오늘의 교훈) sparse-table에 대해서 알아보자.