본문 바로가기
알고리즘 문제 풀이

[백준 11812] K진 트리

by 다빈치코딩 2023. 11. 6.

목차

    반응형

    문제 출처 : https://www.acmicpc.net/problem/11812

     

    11812번: K진 트리

    첫째 줄에 N (1 ≤ N ≤ 1015)과 K (1 ≤ K ≤ 1 000), 그리고 거리를 구해야 하는 노드 쌍의 개수 Q (1 ≤ Q ≤ 100 000)가 주어진다. 다음 Q개 줄에는 거리를 구해야 하는 두 노드 x와 y가 주어진다. (1 ≤ x, y

    www.acmicpc.net

     

    문제 이해하기

     

    최소 공통 조상(LCA)를 찾는 문제 입니다. 다만 일반적인 방법으로 풀 수 없습니다. 왜냐하면 메모리 사용을 최소로 해야 풀리는 문제이기 때문입니다. 그냥 LCA를 푸는 방식으로 풀게되면 메모리 초과를 경험하게 됩니다.

    결국 이 문제는 K진 트리의 특성을 이용하여 각 노드의 부모와 깊이를 찾아 해결해야 합니다. 다행히 이 문제는 적은 에너지 방법을 사용하는데 이 방법은 왼쪽부터 차례대로 추가하는 방식입니다. 이것을 통해 이 트리는 규칙을 가지고 있고, 계산만 잘 한다면 충분히 노드의 부모와 깊이를 알 수 있습니다.

    노드의 깊이 구하기

    먼저 노드의 깊이를 구하는 경우를 생각해 보겠습니다. 노드들의 시작이 1인 경우와 0인 경우를 따져보았을 때 저는 0으로 시작하는 것이 계산이 편했습니다. 자신이 편한 숫자를 시작값으로 트리를 구성하면 됩니다. 아마 계산이 1이 더 편했다면 1을 시작으로 했을 것입니다.

     

     

    노드의 개수는 깊이가 더해질 때마다 K배 늘어 납니다. 위 그림처럼 K가 3인 경우를 생각해 보면 각 깊이의 시작값을 계산할 수 있습니다. 루트를 0이라고 하면 1번째 깊이는 1이 시작이 되고, 두 번째 깊이는 3을 더한 4가 시작이 됩니다. 3번째 깊이는 4에다가 9를 더한 13이 됩니다. 즉 i 번째 깊이의 시작은 K ** i 값이 됩니다. 즉 노드의 깊이는 K ** i를 더해주면서 깊이를 찾다가 노드 값이 넘어가는 순간이 깊이가 되는 것입니다. 이것은 세그먼트 트리에서 트리의 깊이를 구하는 방법과 유사하기 때문에 쉽게 이해가 될 것으로 생각됩니다.

    부모 노드 구하기

    부모의 노드를 구하는 방법은 K로 나누어보면 확인이 됩니다. 위 그림에서 4의 부모는 3으로 나눈 몫인 1입니다. 5의 부모 역시 몫이 1이 됩니다. 문제는 6은 3으로 나누었을 때 2가 되기 때문에 부모의 값이 일치하지 않습니다. 규칙을 살펴보니 노드 값에다가 1을 빼서 K로 나누어 몫을 구해주는 것이 부모 노드를 찾는데 더 효과적으로 판단됩니다. 즉 4의 부모는 (4 - 1) / 3으로 계산하여 1이 되고, 6 역시 (6 - 1) / 3으로 1이 되는 것을 알 수 있습니다.

    코드 작성하기

    중요 로직을 알아보았으니 실제 코드를 작성해 보겠습니다.

    입력 받기

    먼저 입력을 받아보도록 하겠습니다.

    import sys
    input = sys.stdin.readline
    
    mii = lambda : map(int, input().split())
    
    N, K, Q = mii()
    
    for _ in range(Q):
        x, y = mii()
        print(get_dist(x - 1, y - 1))
    

    입력되는 값이 많기 때문에 input을 readline으로 받아 속도를 빠르게 해주었습니다. 그리고 노드의 개수 N, K진 트리의 K값, 입력의 개수 Q를 입력 받았습니다.

    다음 Q개의 줄에는 x, y를 입력받아 거리를 구해줍니다. 거리는 두 노드의 최소 공통조상(LCA)를 찾아 노드까지의 거리를 계산하면 됩니다. 제가 계산한 노드의 시작은 0부터 시작하기 때문에 입력받은 x, y의 값에 1을 빼주었습니다.

    거리 구하기

    def get_dist(x, y):
        if K == 1:
            return abs(y - x)
            
        x_depth = get_depth(x)
        y_depth = get_depth(y)
    
        if x_depth < y_depth:
            x_depth, y_depth = y_depth, x_depth
            x, y = y, x
    
        dist = 0
        while True:
            if x_depth == y_depth:
                break
            
            x = get_parent(x)
            x_depth -= 1
            dist += 1
    
        while True:
            if x == y:
                break
    
            x = get_parent(x)
            y = get_parent(y)
            x_depth -= 1
            y_depth -= 1
            dist += 2
        return dist
    

    거리를 구하는 함수는 LCA를 구하는 함수를 조금 응용하였습니다. 먼저 K값이 1인 경우는 계산할 필요 없이 두 노드의 차를 구해주면 됩니다. 노드의 값이 바로 깊이와 같기 때문에 따로 깊이를 구할 필요가 없습니다.

    다음으로 두 노드의 깊이를 같게하는 부분에서는 거리가 1씩 늘어나고, 최소 공통 조상을 찾는 부분은 두 노드가 같이 움직이기 때문에 거리가 2씩 늘어나게 됩니다.

    노드의 깊이 구하기

    LCA에서 노드의 깊이는 DFS 알고리즘으로 구해주었다면 여기서는 계산으로 구해줍니다. 노드의 깊이를 K배씩 늘려주어 처음으로 구하려는 노드의 값보다 큰 경우에 노드의 깊이를 알 수 있습니다.

    def get_depth(x):
        if x == 0:
            return 0
        
        i = 0
        size = 0
        while True:
            i += 1
            size += K ** i
            if x <= size:
                return i
    

    노드의 부모 구하기

    노드의 부모는 노드의 값에 1을 빼고, K로 나누어주면 쉽게 구할 수 있습니다.

    def get_parent(node):
        if node <= 1:
            return 0
        
        return (node - 1) // K
    

    전체 코드

    그럼 전체 코드를 다시한번 확인해 보겠습니다.

    import sys
    input = sys.stdin.readline
    
    mii = lambda : map(int, input().split())
    
    N, K, Q = mii()
    
    def get_depth(x):
        if x == 0:
            return 0
        
        i = 0
        size = 0
        while True:
            i += 1
            size += K ** i
            if x <= size:
                return i
    
    def get_parent(node):
        if node <= 1:
            return 0
        
        return (node - 1) // K
    
    def get_dist(x, y):
        if K == 1:
            return abs(y - x)
            
        x_depth = get_depth(x)
        y_depth = get_depth(y)
    
        if x_depth < y_depth:
            x_depth, y_depth = y_depth, x_depth
            x, y = y, x
    
        dist = 0
        while True:
            if x_depth == y_depth:
                break
            
            x = get_parent(x)
            x_depth -= 1
            dist += 1
    
        while True:
            if x == y:
                break
    
            x = get_parent(x)
            y = get_parent(y)
            x_depth -= 1
            y_depth -= 1
            dist += 2
        return dist
    
    for _ in range(Q):
        x, y = mii()
        print(get_dist(x - 1, y - 1))
    
    반응형