import sys
input = sys.stdin.readline
N,S = map(int,input().split())
seq = list(map(int,input().split()))
SUM = [0]*(N+1)
result = N+1
for i in range(1,N+1):
SUM[i] = SUM[i-1]+seq[i-1]
if SUM[i] >= S:
for j in range(1,i+1):
if j >= result:
break
if SUM[i]-SUM[i-j] >= S:
result = min(result,j)
break
if result == N+1:
print(0)
else:
print(result)
위 코드는 수열의 0~n까지의 합을 SUM 리스트에 저장하고, SUM 리스트의 차를 이용해서 부분합을 구하는 코드이다.
시간초과 문제를 해결하기 위해서 for문을 뒤에서부터 돌리고, 최근의 결과값보다 커지거나 결과값을 갱신하는 경우 break를 하였다.
하지만 pypy3에서는 통과되었는데, python3에서는 시간초과로 통과되지 못하였다. 알고리즘 자체가 시간복잡도가 O(N2)이다 보니 아무리 for문을 제한해줘도 한계가 있는 것 같았다.
그러다가 이 문제를 투 포인터 알고리즘으로 해결해야 한다는 것을 알게되었고, 이를 구현해보았다.
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**6)
N,S = map(int,input().split())
seq = list(map(int,input().split()))
s = 0
result = N+1
def SUM(start,end):
global s,result
while s < S:
if end == len(seq):
return
s += seq[end]
end += 1
while s >= S:
s -= seq[start]
start += 1
result = min(end-start+1,result)
SUM(start,end)
SUM(0,0)
if result == N+1:
print(0)
else:
print(result)
재귀함수를 통해서 구현하였고, s값이 S보다 작으면 end지점을 1씩 늘려가며 합을 구하고, S보다 커지면 start지점을 1씩 늘려가며 빼주는 방식이다.
이 방식을 사용하니 pypy3로 3360ms가 걸리던 것이 무려 132ms로 30배가량 단축되었다.
오늘의 교훈) 투 포인트 알고리즘을 고려하자