본문 바로가기
알고리즘 문제 풀이

[백준 7812] 중앙 트리(파이썬)

by 다빈치코딩 2023. 8. 28.

목차

    반응형

    문제 출처 : https://www.acmicpc.net/problem/7812

     

    7812번: 중앙 트리

    입력은 여러 개의 테스트 케이스로 이루어져 있다. 각 테스트 케이스의 첫 줄에는 트리의 정점의 수 n이 주어진다. (1 ≤ n ≤ 10,000) 각 정점은 0번부터 n-1번까지 번호가 붙여져 있다. 다음 n-1개 줄

    www.acmicpc.net

    트리 DP의 문제 입니다. 트리와 DP를 합친 문제로 복잡하게 느껴질 수 있지만 하나하나 해결해나가면 문제를 풀 수 있습니다.

    이 문제를 풀기 위해서는 모든 정점들의 거리를 계산해야 합니다. 중앙 정점이 어디가 될지 모르기 때문입니다.

    무작정 문제를 푼다면 다음과 같이 진행 하면 됩니다. 먼저 A를 중앙 정점으로하여 모든 정점까지의 거리를 구해줍니다.

    A-B = 2, A-C = 3, A-D = 9, A-E = 14

    이 값들을 모두 더하면 28이 나옵니다. 이런 방식으로 B, C, D, E를 중앙 정점으로하여 모든 정점까지의 거리를 구합니다. 그리고 최종적으로 B를 중앙 정점으로 하였을 때 최소값인 22를 얻을 수 있습니다. 이렇게 문제를 풀게되면 시간초과가 날 수밖에 없습니다. A를 중앙 정점으로 계산하고, 또 B로 계산해보고, C로도 계산하면서 정점의 갯수가 늘어나면 날수록 계산량이 늘어나기 때문입니다. dfs의 시간 복잡도는 O(N)이기 때문에 모든 정점에 대해 dfs를 돌린다면 $O(N^2)$의 시간복잡도를 가지게 되고, 결국 시간초과가 될 수 밖에 없습니다. 중복된 연산을 피하기 위해서 어떻게 해야 하는지 생각해 보겠습니다.

     

    A를 중앙 정점으로 하고 정점들의 합을 하나하나 분해해서 생각해 보겠습니다.

    A-B로 가는 경로가 4번, B-C로 가는 경로 1번, B-D로 가는 경로 2번, D-E로 가는 경로 1번 으로 나타낼 수 있습니다.

    B를 중앙 정점으로 하고 정점들의 합을 위와 같이 나타내 보겠습니다.

    A-B로 가는 경로 1번, B-C로 가는 경로 1번, B-D로 가는 경로 2번, D-E로 가는 경로 1번 입니다.

    A를 중앙 정점으로 했을 때 A-B로 가는 경로가 4번이었는데, B를 중앙 정점으로 했을 때는 A-B로 가는 경로가 1번으로 바뀐것 빼고 모두 같습니다. A에서 B로 중앙 정점이 이동할 때 변하는 값은 간선의 cost를 몇 번 곱해주느냐 차이밖에 없습니다.

    먼저 A를 기준으로 모든 자식 노드의 갯수를 구해보겠습니다.

    A는 A, B, C, D, E를 모두 자식 노드로 두고 있기 때문에 자식 노드의 갯수는 5가 됩니다. B는 A를 제외하고 모두 자식 노드로 두고 있기 때문에 4 입니다. C는 자식이 자신밖에 없으므로 1입니다. 이렇게 모든 노드들을 대상으로 자식 노드의 갯수를 구할 수 있습니다.

    다음으로 A를 중점으로 하는 거리의 합을 구해보겠습니다.

    C와 E는 자식노드가 자신밖에 없기 때문에 0 입니다. D는 E방향으로 1개의 노드가 있고, 가중치가 5 입니다. 따라서 D의 값은 5 * 1로 5가 됩니다.

    B는 양쪽으로 노드가 있습니다. C쪽으로는 노드가 1개, cost가 1입니다. 따라서 1 * 1 로 1이 됩니다. D쪽으로는 노드가 2개, 가중치가 7로 7 * 2로 14가 됩니다. 여기에 D의 거리값 5를 가지고 있습니다. C방향의 1, D방향의 14, D의 거리값 5를 합쳐 20이 됩니다.

    A는 노드 4개와 가중치 2를 계산하여 4 * 2로 8이 됩니다. 여기에 B의 거리값 20을 합쳐 28이 됩니다. 이것으로 A를 중앙 정점으로 하는 거리 계산이 끝났습니다. 여기서 따로 기억해야 하는 숫자는 A의 합계인 28만 기억하면 됩니다. 다른 값들은 또다시 변경될 예정이기에 기억할 필요가 없습니다. 이제 다른 정점으로 이동하면 거리값을 업데이트 해보겠습니다.

    A에서 B로 중앙 정점을 이동시킬 때 어떻게 되는지 생각해 보겠습니다. A에서 나머지 값은 그대로 있고 A-B로 연결된 간선의 가중치값만 변한다고 하였습니다. A-B의 연결된 간선은 총 4번으로 가중치 2 이기 때문에 4 * 2만큼을 A의 거리값에서 빼줘야 합니다. 그럼 28 - 8로 20이 됩니다.

    다음으로 B가 중앙 정점일 때 A 방향으로 몇 번 가중치를 더해주어야 하나 생각해보면 됩니다. 전체 노드의 갯수 5개에서 현재 B의 자식 노드 4개를 빼면 1개의 노드만 A의 방향쪽에 있습니다. 따라서 2 * 1 만큼 더해줘야 하고 B는 20 + 2로 22의 가중치를 가지게 됩니다. 그럼 total의 값은 아래처럼 변하게 됩니다.

    다음에는 C로 중앙을 옮겼을 때를 생각해 보겠습니다. C 방향으로 향하는 노드의 갯수는 1개로 22에서 1 * 1을 뺀 21이 됩니다. 반대로 B방향의 노드는 5 - 1로 4개의 노드가 있습니다. 따라서 21 + 4 * 1로 25가 됩니다.

    D가 중앙 정점이 되면 C와 마찬가지로 부모가 B입니다. D방향으로 2개의 노드가 있기 때문에 22 - 2 * 7 로 8이 됩니다. B 방향으로는 5 -2 로 3개의 노드가 있기 때문에 8 + 3 * 7 로 29가 됩니다.

    마지막으로 E가 중앙 정점이 된다면 자식노드가 1개 이므로 1개의 노드값을 빼야 합니다. 29 - 1 * 5로 24가 됩니다. 여기에 E에서 나아가는 정점의 갯수를 더해줍니다. 5 - 1로 4개이고, 가중치 5를 곱하면 5 * 4로 20이 됩니다. 여기에 방금 D의 값 24가 더해져 44가 됩니다.

    이것으로 거리의 합이 가장 작은것은 B가 중앙 정점이 되었을 때의 값인 22가 답이 됩니다.

    코드 작성

    그럼 해당 내용을 바탕으로 코드를 직접 작성해 보겠습니다.

    입력 받기

    import sys
    input = sys.stdin.readline
    sys.setrecursionlimit(10**4)
    

    먼저 입력이 많이 있기 때문에 readline을 통해 속도를 높여줍니다. 다음으로 재귀함수를 여러번 사용하기 때문에 제한값을 10 ** 4 정도로 풀어줍니다. 너무 크게 잡으면 메모리 초과가 발생합니다.

    while True:
        N = int(input())
        if N == 0:
            break
    
        arr = [[] for _ in range(N)]
        for _ in range(N-1):
            a, b, w = map(int, input().split())
            arr[a].append((b, w))
            arr[b].append((a, w))
    

    다음으로 정점과 가중치에 대한 정보를 받습니다. 그리고 그 정보를 바탕으로 그래프를 그려줍니다.

    	cnt = [1] * N
    	dp = [0] * N
    	visited = [False] * N
    	dfs_sum(0)

    다음으로 초기화를 해줍니다. cnt 리스트는 자식노드의 갯수를 나타냅니다. 모든 노드는 자기 자신을 자식노드로 두기 때문에 cnt의 초기값은 1입니다. dp는 합계를 계산할 리스트 입니다. 아직 아무 계산을 하지 않았기 때문에 0으로 합니다. 마지막으로 visited는 그래프 탐색의 필수정보인 방문 기록 입니다. 이미 방문한 곳을 또 방문하지 않게 하기위해서 넣었습니다.

    DFS 계산 함수 만들기

    그럼 dfs_sum 함수를 만들어 보겠습니다. dfs_sum은 0번 노드를 중앙 정점으로 하여 모든 자식노드들의 갯수를 구하고, A를 중앙 정점으로 하는 거리값을 얻을 수 있습니다.

    def dfs_sum(curr):
        visited[curr] = True
        for nxt, cost in arr[curr]:
            if visited[nxt]:
                continue
            dfs_sum(nxt)
            cnt[curr] += cnt[nxt]
            dp[curr] += cost * cnt[nxt] + dp[nxt]

    먼저 재귀 함수를 통해 dfs_sum 함수를 끝까지 호출합니다. 그럼 제일 마지막 노드에 도착해야 다음 로직으로 넘어가게 됩니다. 가장 마지막 노드에 도착해서 노드의 갯수를 더해줍니다. cnt 리스트에는 초깃값 1이 들어있기 때문에 상위에서 하위값을 더해주면 자연스럽게 자식 노드의 갯수를 구할 수 있습니다.

    다음으로 구해진 노드의 갯수에 가중치를 곱해서 자식노드의 dp값에 더해줍니다. 이렇게하면 0번노드를 중앙 정점으로 하는 거리값을 구해주게 됩니다.

    최종값 구하기

        visited = [False] * N
        dfs(0)
    
        print(min(dp))

    A를 중앙 정점으로 계산한 내용을 바탕으로 노드의 중앙 정점을 바꿔주면서 계산값을 구해줍니다. 먼저 dfs를 또 돌려줄 것이기 때문에 visited 리스트를 초기화 합니다. 다음으로 dfs 함수를 실행해서 나온 dp 리스트의 최소값을 출력해주면 정답을 얻을 수 있습니다.

    두 번째 dfs 함수

    def dfs(curr):
        visited[curr] = True
        for nxt, cost in arr[curr]:
            if visited[nxt]:
                continue
            dp[nxt] = dp[curr] - cnt[nxt] * cost + (N - cnt[nxt]) * cost
            dfs(nxt)

    두 번째 dfs 함수는 중앙 정점을 옮겨가면서 옮긴 중앙 정점의 거리값을 구해줍니다. A에서 B로 옮겨서 중앙 정점을 구하고 dfs함수로 다음 정점으로 넘어갑니다. 따라서 dfs_sum 함수와는 다르게 재귀 전에 계산을 하게 됩니다.

    전체 코드

    전체 코드를 살펴보겠습니다.

    import sys
    input = sys.stdin.readline
    sys.setrecursionlimit(10**4)
    
    def dfs_sum(curr):
        visited[curr] = True
        for nxt, cost in arr[curr]:
            if visited[nxt]:
                continue
            dfs_sum(nxt)
            cnt[curr] += cnt[nxt]
            dp[curr] += cost * cnt[nxt] + dp[nxt]
    
    def dfs(curr):
        visited[curr] = True
        for nxt, cost in arr[curr]:
            if visited[nxt]:
                continue
            dp[nxt] = dp[curr] - cnt[nxt] * cost + (N - cnt[nxt]) * cost
            dfs(nxt)    
            
    while True:
        N = int(input())
        if N == 0:
            break
    
        arr = [[] for _ in range(N)]
        for _ in range(N-1):
            a, b, w = map(int, input().split())
            arr[a].append((b, w))
            arr[b].append((a, w))
        
        cnt = [1] * N
        dp = [0] * N
        visited = [False] * N
        dfs_sum(0)
    
        visited = [False] * N
        dfs(0)
    
        print(min(dp))
    반응형