재미있는 문제였다.
이 문제는 특이하게도 N의 범위에 따라서 점수를 따로 측정하였다. easy 버전과 hard 버전이 있는 문제들이 많은데 그걸 한 문제로 합쳐놓은 것 같은 문제였다.
처음에 내가 제출한 코드는 이러했다.
import sys
input = sys.stdin.readline
mod = 1000000007
N = int(input())
food = [*map(int,input().split())]
food.sort()
cal = {i:1 for i in range(N-1)}
for i in range(1,N-1):
cal[i] = cal[i-1]*2%mod
result = 0
for i in range(N):
for j in range(i+1,N):
result += (food[j]-food[i])%mod*cal[j-i-1]
result %= mod
print(result)
시간복잡도가 O(N^2)인 코드로, 그냥 단순하게 이중 for문을 돌리는 코드이다.
그랬더니 small 버전만 통과가 되어 50점을 받았다.
그렇게 공책에 끄적끄적 해보다가 중요한 사실을 발견했는데 식을 두 요소의 차이를 가지고 하는게 아니라 한 요소의 값에 대해서 쓸 수 있다는 것을 알게되었다.
이를 반영한 코드는 다음과 같다.
import sys
input = sys.stdin.readline
mod = 1000000007
N = int(input())
food = [*map(int,input().split())]
food.sort()
cal = {i:1 for i in range(N)}
for i in range(1,N):
cal[i] = cal[i-1]*2%mod
result = 0
for i in range(N):
result -= food[i]%mod * cal[N-1-i]
result += food[i]%mod * cal[i]
result %= mod
print(result)
한 요소가 모든 조합에 미치는 영향은 (2^i)-(2^(N-1-i)) 이라는 사실을 조합을 생각하면 쉽게 구할 수 있는데 이를 이용한 코드이다.
결과값을 내는 과정은 시간복잡도 O(N)으로 구현 가능했고, sort하는 과정이 있으니 전체 시간복잡도는 O(NlogN)일 것이다. 이 코드로 제출하자 250점을 받을 수 있었다.
오늘의 교훈) 시간복잡도를 최대한 줄일 수 있는 방법을 생각하자.