import sys
input = sys.stdin.readline
from bisect import bisect_right
def group(last,maxsum,cnt,groupcnt):
global result1,result2
if maxsum >= result1:
return
if cnt == M-1:
maxsum = max(maxsum,sumlist[-1]-sumlist[last])
if maxsum >= result1:
return
result1 = maxsum
result2 = groupcnt + " " + str(N-last)
return
if M-cnt == N-last:
maxsum = max(max(data[last:]),maxsum)
if maxsum >= result1:
return
result1 = maxsum
result2 = groupcnt + " 1"*(M-cnt)
return
idx = bisect_right(sumlist,sumlist[last]+maxsum)
if last != idx-1:
group(idx-1,maxsum,cnt+1,groupcnt+" "+str(idx-1-last))
for i in range(idx,N):
group(i,sumlist[i]-sumlist[last],cnt+1,groupcnt+" "+str(i-last))
N,M = map(int,input().split())
data = [*map(int,input().split())]
sumlist=[0]
for i in range(N):
sumlist.append(sumlist[-1]+data[i])
result1 = 10**6
group(0,sumlist[-1]//M,0,"")
print(result1)
print(result2.strip())
그리디 + DFS로 해결하였다.
알고리즘을 요약하면,
1. 숫자구슬에 대한 누적합 리스트를 만든다.
2. group 함수에서 last는 이전에 그룹을 나눈 좌표, maxsum은 지금까지의 그룹합 중 가장 큰 값, cnt는 그룹의 개수, groupcnt는 각 그룹의 개수를 저장하는 문자열데이터이다.
3. 현 위치에서 그룹을 만들 때 maxsum보다 작으면서 가장 큰 좌표를 bisect로 찾는다.
4. 찾은 좌표가 last보다 큰 경우, 함수 실행
5. 찾은 좌표 이후의 모든 좌표에 대해서 함수 실행
6. cnt가 M-1인 경우 나머지 모든 숫자를 그룹으로 잡고 결과값을 갱신한다.
7. 개수가 0인 그룹이 존재해선 안되므로 M-cnt = N-last면 나머지 그룹을 전부 크기가 1인 그룹이라 판단한다.
8. 결과 출력
각 그룹을 구성하는 구슬의 개수를 출력하는 방식을 문자열로 저장한다는 좀 아름답지 못한 방법을 사용했는데, 역추적을 좀 더 깔끔하게 할 수 있는 방법을 생각해봐야겠다.
오늘의 교훈) 그리디는 어렵다.