import sys
input = sys.stdin.readline
sys.setrecursionlimit(10**7)
def DFS(x):
diff = []
if not child[x]:
return
for c in child[x]:
DFS(c)
DP[x][0] += max(DP[c])
diff.append(DP[c][0]-max(DP[c])+score[x]*score[c])
DP[x][1] = DP[x][0]+max(diff)
N = int(input())
parent = [0]+[*map(int,input().split())]
score = [0]+[*map(int,input().split())]
child = [[] for i in range(N+1)]
for i in range(N):
child[parent[i]].append(i+1)
DP = [[0,0] for i in range(N+1)]
DFS(0)
print(DP[0][0])
간단한 트리 DP 문제.
우수마을 [백준 1949번] 우수 마을 (Python3) (tistory.com) 과 거의 비슷한 방식으로 해결하였고, 플래티넘 문제인데도 불구하고 골드2인 우수 마을보다 오히려 더 쉽게 해결하였다.
알고리즘을 설명하면,
1. 모든 노드의 자식노드로 트리를 만든다.
2. DP[x]는 현 노드부터 시작해서 자식노드들의 시너지점수의 최댓값을 의미하고, DP[x][0]은 x가 멘토링에 속하지 않을 때의 최댓값, DP[x][1]은 x가 멘토링에 속할 때의 최댓값을 의미한다.
3. DFS 과정에서 DP[x][0]는 자식노드의 최댓값의 총 합을 저장, diff에는 각 자식 노드마다 최댓값과 멘토링에 속하지 않을 때의 점수의 차이 + 현 노드와 멘토링했을 때의 점수를 저장한다.
4. DP[x][1]은 DP[x][0] + max(diff)를 저장한다.
5. DP[0][0]를 출력한다.
우수마을 문제를 풀면서 3번 조건이 자명하다는 사실을 알지 못해서 조금 어렵게 풀었었는데, 그때의 풀이 방식이 이 문제에 더 간단하게 적용 가능해서 한결 더 쉽게 풀 수 있었다.
오늘의 교훈) 다양한 풀이 방법을 고민하자