[백준 2295] 세 수의 합
문제 출처 : https://www.acmicpc.net/problem/2295
집합에 포함된 세 수를 더한 결과가 집합 내에 있는 가장 큰 수가 되는 경우를 찾는 문제 입니다. 세 수 x, y, z와 그 결과로 가장 큰 수 k를 찾아야 합니다. 이 때 꼭 기억해야 하는 부분은 x, y, z, k의 값이 서로 같아도 된다는 부분 입니다. 저는 세 수가 같을수는 없다고 생각하고 아무리 풀어도 틀리다고 나와 문제를 다시 제대로 읽어보니 숫자가 중복되어도 상관 없다는 부분을 찾을 수 있었습니다. 저와 같은 실수를 하지 마시기 바랍니다.
아이디어
세 개의 수를 더하기 때문에 시간복잡도가 높을 것으로 생각이 듭니다. 하지만 조금만 더 생각해보면 시간복잡도를 줄일 수 있는 아이디어가 있습니다. 우리가 구해야 하는 공식은 다음과 같습니다.
x + y + z = k
이것을 조금만 수정하면 다음과 같이 됩니다.
x + y = k - z
세 개의 수를 더하는 공식에서 두 수의 합과 두 수의 차를 비교하는 문제로 바뀌었습니다. 여기에 정렬만 잘 되어 있다면 가장 큰 결과를 찾기 위해 여러번 계산할 필요 없이 k를 가장 크게 만들고, z를 가장 작게 만든다면 첫 번째 만난 결과가 가장 큰 수가 되는 것입니다.
문제의 예제를 보면 {2, 3, 5, 10, 18} 에서 가장 큰 수가 18이고, 가장 작은 수가 2 입니다. 18을 k로 두고 2를 z로 생각하고 문제를 해결해 보는 것입니다. 먼저 18 - 2 인 16을 x + y로 구할 수 있는지 확인해 봅니다. 결과가 없기 때문에 이제 z를 2의 다음수인 3으로 바꿔줍니다. 18 - 3인 15를 x + y를 찾아 봅니다. 그럼 x가 5, y가 10일 때 15의 결과를 얻습니다. 그럼 예제의 답인 3 + 5 + 10 = 18을 얻을 수 있습니다.
만약 이래도 결과가 없다면 이제 k값을 18에서 10으로 바꿔주어 다시 결과를 확인하면 됩니다. 가장 먼저 결과를 낸 숫자가 가장 큰 결과라는 것을 이해하면 쉽게 해결할 수 있습니다.
이 문제는 두 가지 방법으로 풀 수 있습니다. 첫 번째로 이분탐색을 이용하여 문제를 해결하는 방법입니다. 두 번째로 set 이라는 집합 자료형을 사용하여 문제를 해결하는 방법입니다. set은 해시 테이블을 가지고 있어 set을 검색하는 in 을 사용 시 평균 시간복잡도가 O(1) 입니다. 즉 이분 탐색이 필요없이 검색을 할 수 있습니다.
코드 작성하기
그럼 문제를 직접 풀어보도록 하겠습니다. 먼저 이분탐색으로 풀어보고 다음으로 set을 사용하여 풀어보도록 하겠습니다.
이분 탐색으로 문제 해결하기
먼저 이분 탐색으로 어떻게 문제를 해결하는지 알아보겠습니다.
입력 받기
N = int(input())
arr = []
for _ in range(N):
temp = int(input())
arr.append(temp)
arr.sort()
첫 번째 입력은 자연수 N을 입력 받습니다. 다음으로 N개의 자연수를 입력 받습니다. 입력 받은 자연수를 arr이라는 리스트에 담아 두었습니다. 그리고 arr를 정렬하였습니다. 위에서 가장 큰 수 k에서 가장 작은 수 z를 빼는 방식으로 세 수의 합을 구한다고 했습니다. 따라서 arr을 정렬해 두어야 합니다. 문제에서 오름차순으로 입력된다는 부분이 없기 때문에 정렬을 해놓지 않으면 안됩니다.
두 수의 합 구하기
sum_arr = []
for x in range(N - 1):
for y in range(x, N):
sum_arr.append(arr[x] + arr[y])
sum_arr.sort()
두 수의 합을 미리 구해두도록 하겠습니다. x와 y의 합을 미리 구하는 이유는 정답을 찾기 위해 매 번 계산하는 것이 아니라 한 번 계산해두고 바로바로 가져다 쓰기 위함 입니다. 이 때 세 수는 중복이 될 수 있기 때문에 y의 시작이 x부터 입니다. 보통 중복되지 않은 경우 y의 시작 값은 x + 1로 합니다.
두 수의 합으로 만들어진 sum_arr을 정렬하여 작은 경우부터 체크할 수 있도록 하였습니다.
결과 출력하기
def solve():
for k in range(N - 1, -1, -1):
for z in range(k + 1):
rst = arr[k] - arr[z]
if binary_search(rst):
return arr[k]
print(solve())
k값은 가장 큰 값부터 시작합니다. 따라서 range 범위가 가장 끝인 N - 1부터 0번째 까지 돌도록 하였습니다. 다음으로 z값은 앞에서부터 k까지 탐색합니다.
rst에는 k값에서 z값을 빼준 결과가 있습니다. 이분 탐색을 통해 rst 값이 두 수의 합이 맞다면 결과를 리턴하고 함수를 종료합니다.
이분 탐색
def binary_search(rst):
start = 0
end = len(sum_arr) - 1
while start <= end:
mid = (start + end) // 2
if sum_arr[mid] == rst:
return True
elif sum_arr[mid] < rst:
start = mid + 1
else:
end = mid - 1
return False
이분 탐색 부분 입니다. sum_arr을 이분 탐색을 통해 결과를 얻습니다. sum_arr 리스트에는 x + y의 값이 들어 있습니다. 그리고 rst 값은 k - z 값이 들어 있습니다. 즉 rst를 찾으면 x + y + z의 가장 큰 값을 찾은 것과 같습니다. 이분 탐색에 대해서는 따로 설명하지 않겠습니다.
집합(set) 사용하기
set 자료형을 사용하여 문제를 해결하는 방법을 알아보겠습니다. set 은 유용한 자료형이기 때문에 사용법을 기억해 두면 좋습니다. 많은 부분이 겹치기 때문에 설명이 필요한 부분만 보겠습니다.
set으로 두 수의 합 구하기
sum_set = set()
for x in range(N - 1):
for y in range(x, N):
sum_set.add(arr[x] + arr[y])
sum_set이라는 집합을 선언하고 add를 통해 두 수의 합을 추가하였습니다. set은 입력값에 대한 순서가 없습니다. 해시 테이블로 빠르게 검색이 가능하기 때문에 앞에서처럼 정렬해둘 필요가 없습니다. 그리고 set은 중복된 값이 없습니다. 중복값을 제외하고 싶을 때 set 자료형으로 바꿔주면 중복값을 모두 삭제할 수 있습니다.
문제 해결하기
def solve():
for k in range(N - 1, -1, -1):
for z in range(k + 1):
rst = arr[k] - arr[z]
if rst in sum_set:
return arr[k]
print(solve())
바뀐 부분은 sum_set에서 그냥 rst가 존재하는지만 따져준 부분입니다. 앞에서는 binary_search 라는 함수를 따로 만들어 주었는데 set은 그럴 필요가 없어 편리 합니다.
전체 코드
이분 탐색 풀이
그럼 이분 탐색을 통해 문제를 해결하는 전체 코드를 보겠습니다.
N = int(input())
arr = []
for _ in range(N):
temp = int(input())
arr.append(temp)
arr.sort()
sum_arr = []
for x in range(N - 1):
for y in range(x, N):
sum_arr.append(arr[x] + arr[y])
sum_arr.sort()
def binary_search(rst):
start = 0
end = len(sum_arr) - 1
while start <= end:
mid = (start + end) // 2
if sum_arr[mid] == rst:
return True
elif sum_arr[mid] < rst:
start = mid + 1
else:
end = mid - 1
return False
def solve():
for k in range(N - 1, -1, -1):
for z in range(k + 1):
rst = arr[k] - arr[z]
if binary_search(rst):
return arr[k]
print(solve())
set 자료형 사용
set 자료형을 사용한 전체 코드를 확인해 보겠습니다. 어느 방법이 좋은지는 알 수 없습니다. 자신에게 편리한 방법으로 진행하면 됩니다.
N = int(input())
arr = []
for _ in range(N):
temp = int(input())
arr.append(temp)
arr.sort()
sum_set = set()
for x in range(N - 1):
for y in range(x, N):
sum_set.add(arr[x] + arr[y])
def solve():
for k in range(N - 1, -1, -1):
for z in range(k + 1):
rst = arr[k] - arr[z]
if rst in sum_set:
return arr[k]
print(solve())