import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**5)
def BFS():
dq = [1]
while dq:
now = dq.pop()
for next in graph[now]:
if not child[next]:
parent[next] = now
child[now].append(next)
dq.append(next)
def DFS(now):
childDP[now][0] = 1
for next in child[now]:
DFS(next)
for k in range(K):
childDP[now][k+1] += childDP[next][k]
def count(now):
DP[now][0] = 1
for k in range(K):
DP[now][k+1] = childDP[now][k+1] + DP[parent[now]][k]
if k and now!=1:
DP[now][k+1] -= childDP[now][k-1]
for next in child[now]:
count(next)
N,K = map(int,input().split())
graph = [[] for i in range(N+1)]
for _ in range(N-1):
a,b = map(int,input().split())
graph[a].append(b); graph[b].append(a)
parent = [0]*(N+1); child = [[] for i in range(N+1)]
BFS()
childDP,DP = [[[0]*(K+1) for i in range(N+1)] for i in range(2)]
DFS(1); count(1)
print(max(map(sum,DP)))
트리 응용 문제
로스팅하는 엠마는 바리스타입니다 [백준 15647번] 로스팅하는 엠마도 바리스타입니다 (Python3) (tistory.com) 와 풀이가 거의 비슷하다.
과정을 설명하면,
1. BFS로 트리구조를 만든다. (1을 루트노드로 각 노드의 부모노드, 자식노드 저장)
2. DP를 N*K 크기의 배열로 만든다. childDP[n][k]는 현재 노드에서 거리가 k인 자식노드의 수를 의미하고, DP는 거리가 k인 노드의 수를 의미한다.
3. DFS를 이용해서 각 노드에서 0~K 거리까지의 자식노드의 개수를 childDP에 저장한다. 자식노드와 부모노드의 거리가 1이므로, 자식노드에서 k거리의 노드 수는 곧 현재 노드의 k+1 거리의 노드 수이다.
4. count 함수를 이용해서 각 노드에서 0~K 거리까지의 노드 (자식, 부모 등 모두 포함) 개수를 DP에 저장한다.
5. 0~K 거리까지의 노드 수의 합이 가장 큰 경우의 값을 출력한다.
트리+DP 문제는 항상 재미있는 것 같다.
오늘의 교훈) 다양한 문제를 해결하자