LSTM 정복하기

Posted by Jeonghwan Lee on January 15, 2022 · 4 mins read

이번 포스트에서는, LSTM을 정복해보려고 한다. 나에게 LSTM은 딥러닝을 시작한 이후로 처음으로 좌절감을 맛보게 한 사악한 녀석이다. CNN과 RNN까지는 꾸역꾸역 이해했지만, LSTM은 수많은 게이트의 향연(?)으로 내 지능을 시험했다. 이번에는 이 LSTM의 차근차근 뜯어보면서, 내가 나름대로 이해했던 방식을 소개하려고 한다.
(LSTM에 대한 소개를 하기에 앞서, RNN에 대한 기본적인 이해는 했음을 가정한다.)

RNN과 장기기억 의존성 문제

LSTM (Long Short Term Memory)은 직역하면 “장단기 기억모델”쯤 될 것이다. 이 모델은 RNN이 가지고 있는 고질적인 문제인 “장기기억의존성(“Long Term Dependecy”)을 해결하기 위해 등장했다.

RNN

위 그림은 많이들 봤을, RNN의 구조이다. 현재 시점의 상태 (hidden state t)는 이전 시점의 상태의 (hidden state t_1)로 부터 영향을 받게 된다. Sequence가 길어질수록 계산 해야 하는 hiddn state도 많아질 것이다.

RNN Cell

RNN의 내부 계산 과정을 보면, 위와 같이 이루어져 있다. 결국엔 tanh 함수를 통해 최종 출력값을 내놓게 된다.

tanh

하지만 tanh함수는 위와 같이, 미분값이 0~1 사이에 있다. 또한 x의 값이 0에서 멀어질수록 미분값도 0에 가까워지게 된다. 즉 계산을 하면 할수록 기울기가 점점 작아질 확률이 높아진다. RNN모델이 처리해야 할 Sequence의 길이가 길어질수록, 기울기가 소실하여 결국에는, 이전의 정보를 기억하지 못하는 문제가 발생하는 것이다.

LSTM

LSTM

위는 LSTM의 전체적인 모델 구조이다. 벌써부터 머리가 아프다. 갑자기 요상한 연산들이 여러개 붙어버렸다. 그래도 겁먹지 말고 하나하나씩 정복해보자.

Cell state

LSTM의 핵심은 “Cell State”라고 불리는 이 “C”이다.
C는 “이 정보는 얼마만큼 중요한가?”를 결정한다. 즉 중요하다고 생각하는 정보를 담고 있는 상자라고 생각하면 된다.

수식으로는 아래와 같다.

\[h_t=tanh(c_t)\]

즉, LSTM에서 hidden state는 단순히 $c_t$에 tanh 함수를 적용한 것이다.

Output Gate

한편 $tanh(c_t)$의 정보를 다음 시점으로 “얼마나 흘려 보낼 것인가?” 역시 LSTM에서는 중요하게 다룬다. 즉 무턱대고 모든 정보를 다음 시점으로 전달하는 것이 아니라, 필요한 만큼(중요하게 생각하는 만큼)보내는 것이다. 이 역할을 담당하는 것이 “게이트”이다. 즉 게이트를 열고 닫는 만큼 정보가 흐르는 것이다.

이 게이트는 결국 최종출력인 $h_t$를 담당하므로 “output gate”라고 불린다.

output gate의 수식은 다음과 같다.

Output Gate

각 문자의 윗첨자 (o)는 output gate를 의미한다. RNN과 마찬가지로, 현재 시점의 input($X_t$)과 이전 시점의 hidden state($h_{t-1}$)에 각각의 가중치가 곱해지고 bias가 추가된다. 그리고 시그모이드(σ) 함수를 적용한다.

최종 출력값인 hidden state는 아래와 같다.

\[h_t=o⊙tanh(c_t)\]

즉 output gate의 출력값 o와 $tanh(c_t)$의 원소곱이다.
이를 해석하면, “cell state에서 중요하게 생각하는 정보 $tanh(c_t)$를, o만큼 흘려보내라는 것이다.”

한편, tanh의 출력은 -1.0~1.0 실수입니다. 이 -1.0~1.0의 수치를 그 안에 인코딩된 “정보”위 강약(정도)를 표시한다고 해석할 수 있습니다. 한편 시그모이드 함수의 출력은 0.0~1.0 실수이며, 데이터를 얼마만큼 통과시킬지를 정하는 비율을 뜻합니다. 따라서 (주로) 게이트에서는 시그모이드 함수가, 실질적인 “정보”를 지는 데이터에서는 tanh 함수가 활성화 함수로 사용됩니다.
(밑바닥부터 시작하는 딥러닝2)

지금까지의 계산 과정을 그림으로 표현하면 아래와 같다.

LSTM output gate operation

Forget Gate

정보를 기억하는 것 만큼 중요한 것이, 필요없는 정보를 “잊어버리는 것”이다. forget gate는 중요하지 않은 정보를 “망각”하는 역할을 담당한다.

Forget gate

Foget gate operation

output gate의 수식과 크게 다른 것이 없다. 윗첨자 (f)가 붙어, forget gate 임을 명시해준다.

forget의 출력값 f를 이용하여, 아래와 같이 cell state를 구할 수 있다.

\[c_t=f⊙c_{t-1}\]

이는, “이전 시점의 cell state 정보를 f만큼만 기억하라” 정도로 해석할 수 있을 것이다.

New Information Cell (새로운 기억 셀)

forget gate를 통과하며, 필요없는 정보는 삭제가 되었으니, 이제는 새로운 기억을 주입해주어야 한다.

New Information Cell

g operation

g는 o,f와 달리 시그모이드(σ)가 아닌 tanh함수를 적용하는 것을 볼 수 있다. 이는 게이트로써의 역할이 아닌, “중요한 정보”를 “추가”하는 역할을 기대하기 때문이다.
(앞서, 시그모이드 함수는 “게이트”의 역할, tanh 함수는 실질적인 “정보”를 가지게 하는 역할을 한다고 설명했다.)

여하간, f연산을 거치며, 일정량의 정보를 망각하게된 $c_t$에 g연산을 통해 새로운 정보를 주입할 수 있게 되었다.

이렇게 새롭게 추가된 정보를 담고 있는 Cell을 $\hat{C_t}$라고 하자.

Input Gate

마지막으로 g에 Input gate를 추가한다.

Input gate

Input gate의 수식은 다음과 같다.

I operation

Input gate는 “게이트” 이기 때문에 새롭게 추가되는 정보를 “얼마나 반영할 것인가?”를 결정한다.

최종적으로 만들어지는 Cell state의 수식은 다음과 같다.

\[c_t=f⊙c_{t-1}+\hat{c_t}⊙i\]

“새롭게 만들어지는 Cell state($c_t$)는 이전 Cell state ($c_{t-1}$) 에서 일정량의 정보를 망각(또는 f만큼만 기억)하지만, 새롭게 추가되는 정보 ($\hat{c_t}$)를 i만큼 반영한다.”

References

  1. Understanding LSTM Networks

  2. Deep Learning from Scratch 2