import sys
input = sys.stdin.readline
def updateseg(a,start,end,x):
if a == end:
seg[x] += 1
return
mid = (start+end)//2
if a <= mid:
updateseg(a,start,mid,x*2)
else:
updateseg(mid,start,mid,x*2)
updateseg(a,mid+1,end,x*2+1)
def sumseg(a,start,end,x):
global SUM
SUM += seg[x]
if start==end:
return
mid = (start+end)//2
if a <= mid:
sumseg(a,start,mid,x*2)
else:
sumseg(a,mid+1,end,x*2+1)
N = int(input())
data = [*map(int,input().split())]
A = []
for i in range(N):
A.append((data[i],i))
A.sort()
seg = [0]*(4*N)
result = 0
for a in range(N):
idx = A[a][1]
SUM = 0
sumseg(idx,0,N-1,1)
if idx+SUM>a:
result += idx+SUM-a
updateseg(idx,0,N-1,1)
print(result)
세그먼트 트리를 사용해서 해결하였다.
버블 소트 [백준 1377번] 버블 소트 (Python3) (tistory.com)와 수열과 쿼리 21 [백준 16975번] 수열과 쿼리 21 (Python3) (tistory.com)에서 사용한 방법을 응용하였다.
버블 소트에서는 수열을 튜플로 저장한 후 sort해서 현재의 위치와 index의 차이를 구하였고, 수열과 쿼리 21에서 구간 전체에 대한 덧셈을 세그먼트 트리로 표현하였는데, 이 두 가지를 이용한다.
그리디하게 생각하면, 가장 작은 수 부터 순서대로 옮기다 보면 자동으로 순서가 맞춰질 것이다.
이때, 옮겨야되는 횟수는 현재의 위치와 원래의 index의 차이이다.
그러나 문제는 이미 앞에서 옮긴 숫자가 존재하면 인덱스가 밀리게 되는데, 이를 세그먼트 트리로 이용하는 것이다.
예시를 들면, 3 2 1이 있다. 이때 1 2 3이 되려면 1은 원래 인덱스와 현 위치의 차이가 2이므로 두번 옮기면 된다. 이때 1 3 2가 되면, 3과 2는 인덱스가 한 칸씩 앞으로 밀려났다. 따라서 0~(1의 원래 인덱스) 까지의 구간의 인덱스는 +1을 해주는 것이다.
따라서 (원래 인덱스 + 밀린 인덱스 - 정렬한 위치) 의 합이 곧 정답이다.
오늘의 교훈) 세그먼트 트리는 유용하다.