저번 주 이번 주 재귀와 백트래킹을 공부했기에 글을 쓴다.
근데 사실 글을 쓰고 있는 지금도 내가 재귀와 백트래킹을 제대로 이해하고 있는지 모르겠다.
확실한 건 백트래킹 너무 어려웠는데 이번 기회에 감은 잡았다.. 정도..?
재귀
먼저 재귀를 살짝 짚고 넘어가자면
재귀는 반복되는 계산에 주로 쓰이며 자신을 호출하여 점점 더 작은 계산을 한다.
작은 계산의 답이 어느 일정 값이 수렴하면 그 값을 return 하며 답을 구한다.
재귀는 스택 구조라 호출될수록 메모리가 계속해서 쌓이고
함수가 return되면 그 함수가 갖고 있었던 메모리 또한 비워지게 되기에,
그때서야 메모리를 반환하게 된다.
즉 return 조건을 만나기 전까지 재귀는 계속해서 메모리를 쌓기에
아무 생각 없이 재귀를 쓰다간 메모리 초과가 날 것이다.
재귀의 구조를 시각화하면 다음과 같다.
별 찍기 - 10 https://www.acmicpc.net/problem/2447
이 문제를 두가지 방법으로 풀었다.
## 풀이 1
k = int(input())
def wornl1(cnt,start):
if cnt == k:
return start
answer = []
for j in range(3):
for st in start:
if j == 1:
answer.append(st + ' ' * cnt+ st)
else:
answer.append(st * 3)
return wornl1(cnt*3,answer)
start = ['***','* *', '***']
if k == 3:
for ans in start:
print(ans)
else:
answer = wornl1(3,start)
for ans in answer:
print(ans)
풀이 1번은 가장 작은 값부터 호출해서 작은 값 -> 큰 값을 호출하는 구조이다.
함수는 return 하는 즉시 메모리가 반환되기에 작은 값부터 시작하면 메모리적으로 좀 더 효율적일까 봐 이렇게 짰다..
호출되는 재귀 사진과 달리 일방통행 구조라고 생각하면 된다. (검정색 선 라인)
## 풀이 2
k = int(input())
def wornl1(cnt):
if cnt == 3:
return ['***','* *', '***']
lines = wornl1(cnt//3)
answer = []
for j in range(3):
for line in lines:
if j == 1:
answer.append(line + ' ' * (cnt//3) + line)
else:
answer.append(line * 3)
return answer
answer = wornl1(k)
for ans in answer:
print(ans)
풀이 2는 큰 값으로 시작해 큰 값 -> 작은 값으로 계속해서 호출된다 (lines = wornl1(cnt//3),
if cnt == 3: 은 호출의 끝으로 (가장 작은 값) 여기서 return 한 값은 wornl1(9)에서 쓰이고, wornl1(9)에서 최종적으로 계산되는 값은 return 되어 wornl1(27)에서 쓰인다.
즉, 작은 값에서 얻은 return 값 -> 큰 값으로 계속해서 return 하여 최종적으로 큰 값에서 처음 호출한 곳에서 문제의 정답을 return 하게 되는 구조이다.
두 가지 풀이로 짠 이유는 스택이 쌓일 때의 메모리를 생각해서 2번으로 짰다가 1번으로 다시 짠 건데, 백준 돌려보니 메모리는 거의 비슷했다. 크흠...;; 왜인지는.. 동혁님한테 여쭤본다 해놓고 .. 내가 말을 잘 못해서.. 제대로 못여쭤봤다.ㅎ..ㅎ 다시 물어봐야지
백트래킹
쉽게 말하면 가지 치기다. dfs탐색을 하면서, 문제의 조건에 맞지 않으면 더 깊숙이 들어가기 전에 가지를 쳐버리는 것이다. dfs를 재귀 구조로 짜면 백트래킹 구조가 된다.
한 가지 주의해야 할 점이 있다면, 깊게 들어갔다가 "오잉? 여기가 아니네? " 또는 "끝까지 다 탐색했어~ 다른곳도 탐색해볼까~?" 하고 이전 함수로 빠져나올 때는 해당 값을 원상 복구시켜줘야 한다는 점이다.
백준에 있는 백트레킹 중 대표적인 문제 N과 M 시리즈~
난 그 중에 15650 을 풀었다.
N과 M (2) - https://www.acmicpc.net/problem/15650
# 중복없이 N개를 고른 수열 즉 조합을 찾아라 이 말인가~~?
N, M = map(int,input().split())
lst = [0 for _ in range(M)]
def backtracking(lst,idx,start):
if lst[-1] != 0:
print(" ".join(map(str,lst)))
return
for i in range(start,N+1):
lst[idx] = i
backtracking(lst,idx+1,i+1)
lst[idx] = 0
backtracking(lst,0,1)
만약 이 코드에서 lst[idx] = 0 을 없애면
이런 끔찍한 사태가 벌어진다.
연산자 끼워넣기 -https://www.acmicpc.net/problem/14888
연산자 끼워넣기 문제도 N과 M (2) 와 아주 비슷한 구조로 풀 수 있다.
N = int(input())
num = list(map(int,input().split()))
cal = list(map(int,input().split()))
#cal = ['+','-','*','/']
answer = []
def backtracking(idx,ans):
if idx == N:
answer.append(ans)
return
if cal[0]:
cal[0] -= 1
backtracking(idx+1, ans + num[idx])
cal[0] += 1
if cal[1]:
cal[1] -= 1
backtracking(idx+1,ans - num[idx])
cal[1] += 1
if cal[2]:
cal[2] -= 1
backtracking(idx+1, ans * num[idx])
cal[2] += 1
if cal[3]:
cal[3] -= 1
backtracking(idx+1, int(ans / num[idx]))
cal[3] += 1
backtracking(1,num[0])
#print(answer)
answer.sort()
print(answer[-1])
print(answer[0])
재귀로 호출하고 다시 빠져나올 때 원상복구 시켜주기~~
다음주는 dp!