IT Repository

(3) LSTM 본문

RNN/Study

(3) LSTM

IT찬니 2020. 1. 17. 13:28

 

 

 

 

Vanishing Gradient of Vanilla RNN

앞서 Vanilla RNN에서 이야기했던 Gradient Vanishing 문제를 좀더 수식적으로 이해가능하게 살펴보겠습니다.

 

설명에 앞서 식을 좀더 간단하게 Visualize하기 위해서

$h_t = tanh(U \cdot x_t + W \cdot h_{t-1})$ 식을

1. 위 식에서 $x_t$와 $h_{t-1}$을 concatenation 하고 ($[x_t, h_{t-1}]$)
2. 두개의 파라미터 U와 W를 아우르는 하나의 파라미터인 새로운 W와 점곱

해서 아래와 같이 바꿔서 쓰겠습니다.
(행렬 연산을 따라가다보면 동일한 과정이라는 것을 이해할 수 있을 것입니다.)

$h_t = tanh(W \cdot [x_t, h_{t-1}])$

 

자, 이제 Vanilla RNN으로 시퀀스를 처리하는 모델을 학습시켜 봅시다.
시퀀스의 길이가 3인 경우 하나의 시퀀스는 아래와 같이 처리됩니다.

$h_{t-2} = tanh(W [x_{t-2}, h_{t-3}]) \\ h_{t-1} = tanh(W [x_{t-1}, h_{t-2}]) \\ h_t = tanh(W [x_t, h_{t-1}])$

세 개의 수식을 하나의 식에 모두 대입해서 보면 아래와 같은 수식이 성립합니다.

$h_t = tanh(W[x_t, tanh(\dots tanh(\dots h_{t-3}))])$

여기서 볼 수 있듯이 너무 많은 tanh가 합성되는 것을 확인할 수 있습니다.
tanh는 x축으로 양 끝의 Gradient가 0으로 수렴하는 함수였습니다.

(아래의 그래프를 참고해주세요.)

In [24]:
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-3, 3, 601)
f = np.tanh(x)
g = np.gradient(f, x)

plt.plot(x, f, "r", label="tanh(x)")
plt.plot(x, g, "g", label="d/dx tanh(x)")
plt.legend()
plt.grid()
 
 

따라서 빠르게 Gradient가 0으로 수렴해버리므로 역전파가 잘 되지 않는 문제가 발생하게 됩니다.
특히, 긴 시퀀스의 경우 더 많은 tanh를 거쳐야 하므로 Vanilla RNN은 긴 시퀀스를 학습하기가 어렵다는 결론에 도달하게 됩니다.

그러면 이제 이를 개선한 두 가지 구조인 LSTM과 GRU를 알아보고 어떠한 원리로 이를 개선할 수 있었는지 알아봅시다.

 

LSTM (Long-Short Term Memory)

LSTM의 메인 아이디어는 Information Flow를 추가하는 것입니다.
시퀀스의 길이가 길어질 경우 정보가 멀리있는 곳까지 전달되기 어려운 기존의 RNN 구조에서
Information Flow를 추가함으로써 정보가 소실되지 않고 멀리까지 갈 수 있도록 한 것입니다.

구체적으로 아래의 두 가지 방법을 구현함으로써 개선했습니다.

  1. Cell State (Information Flow)
  2. Gate
 

Cell State

Cell state라고 불리는 Information Flow가 시간에 따라 정보를 공급합니다.
남길 건 남기고, 버릴 건 버리고, 새로 추가할 건 추가하여 중요한 정보만 계속 흘러갈 수 있도록 합니다.
(남기고 버리고 추가하는 작업을 위해서 Gate 라고 불리는 Coefficient를 사용합니다.)

Vanilla RNN의 경우 이전 스텝의 상태, 즉 단기 상태인 $h_t$만을 고려합니다.
여기에 장기 상태인 Cell state를 둠으로써 단기 상태를 고려하면서 장기 상태에서 정보를 공급받을 수 있도록 하는 것입니다.
이와 같이 정보의 상태를 단기 상태와 장기 상태로 구분하는 것이 LSTM의 핵심이라고 할 수 있습니다.

Cell state에는 현재 스텝의 정보다음 스텝으로 넘겨주어야 할 정보가 함께 있습니다.
각 스텝에서의 아웃풋, 즉 Hidden state는 Cell state에 있는 현재 스텝에서 내보내야 할 정보를 가공해서 만들어지게됩니다.

 

Gate

Gate는 학습을 통해 만들어지는 0~1 사이의 계수입니다.
이 계수들을 이전 스텝의 아웃풋에 element-wise하게 곱함으로써 정보의 중요도에 따라 그 비중을 컨트롤하게 됩니다.
즉, 계수가 0이면 Gate를 닫아 정보의 흐름을 막고 계수가 1이면 Gate를 열어 모든 정보를 흘려보냅니다.

Gate Coefficient는 아래와 같이 계산됩니다.

$g_t = \sigma(W_g \cdot x_t)$

$W_g$는 Linear Transformation을 위한 파라미터 행렬, $x_t$는 인풋 벡터를 의미합니다.
즉, $g_t$는 인풋 벡터를 Linear Transformation하고, 이를 Coeffecient값으로 변환하기 위해서 Sigmoid Activation합니다.

위와 같이 Coefficient 값을 만들고 나면 아래와 같이 Cell state를 변형시킵니다.

$C'_t = g_t \ast C_t$

 

Work Flow of LSTM

이제 이전 층에서 전달된 $h_{t-1}$과 새로운 인풋 $x_t$가 Cell State와 Gate를 통해 어떻게 LSTM에서 적용되는지 알아보겠습니다.

  1. Gate coefficient를 계산한다.
    $\begin{eqnarray} f_t &=& \sigma~(W_f \cdot [h_{t-1},~x_t] + b_f) & ~~~~~\text{(Forget gate: 셀 스테이트의 불필요한 정보를 삭제)} \\ i_t &=& \sigma~(W_i \cdot [h_{t-1},~x_t] + b_i) & ~~~~~\text{(Input gate: 임시 상태에서 셀 스테이트로 추가할 정보 결정)} \\ o_t &=& \sigma~(W_o \cdot [h_{t-1},~x_t] + b_o) & ~~~~~\text{(Output gate: 이번 타임스텝에서 출력할 정보를 결정)} \end{eqnarray}$

  2. 이전 스텝의 상태와 현재 스텝의 인풋을 조합하여 임시 상태를 만든다. (= Vanilla RNN 셀의 $h_t$)
    $g_t = tanh~(W_C \cdot [h_{t-1},~x_t] + b_C)$

  3. Forget gate와 Input gate를 통해 $C_t$를 만든다. (Cell state update)
    $C_t = f_t \ast C_{t-1} + i_t \ast g_t$

  4. Output gate를 통해 $h_t$를 만든다.
    $h_t = o_t \ast tanh~(C_t)$

  5. $C_t$와 $h_t$를 다음 스텝으로 전달한다.

 

(Gate와 Cell state의 차원은 Hidden state의 차원과 동일합니다.)

 

How LSTM Solves Gradient Vanishing

위의 LSTM의 흐름도 사진에서 Cell state 부분인 $C_t$만 주목해서 생각해봅시다.

Cell state에는 $\otimes$ 와 $\oplus$ 의 두 개의 Operation만 있을 뿐 Non-linear function이 없습니다.
즉, 어떠한 타임스텝 t 에서 Cell state에 추가된 정보가 t+n 스텝에서 활용되었을 떄 tanh 함수없이 Linear하게 전달되어 Gradient Vanishing 문제를 상당부분 해결합니다.

 

Peephole Connection

Peephole Connection은 2000년에 F. Gers와 J.Schmidhuber에 의해 발표된 논문 "Recurrent Nets that and Count" 에서 제안된 LSTM의 Gate 연결 기법입니다.
아래의 사진을 먼저 보겠습니다.

각 Gate에 기존 LSTM과 다르게 Cell state와 연결된 선이 추가된 것을 확인할 수 있습니다.

기존의 LSTM은 Gating을 위해서 이전 타임스텝의 Hidden state $h_{t-1}$ 만을 입력으로 받습니다.
그러나 Peephole Connection을 통해서 Cell state를 연결하면, 단기 상태인 $h_{t-1}$뿐만 아니라 장기 상태인 $c_{t-1}$이 추가됩니다.
이는 Gate를 통해 정보를 컨트롤함에 있어서 좀 더 넓은 맥락을 고려하는 효과를 가져올 수 있습니다.

위의 그림과 아래 수식을 통해 기존 LSTM과의 차이점을 확인해보시기 바랍니다.

$\begin{eqnarray} f_t &=& \sigma~(W_f \cdot [C_{t-1},~h_{t-1},~x_t] + b_f) \\ i_t &=& \sigma~(W_i \cdot [C_{t-1},~h_{t-1},~x_t] + b_i) \\ o_t &=& \sigma~(W_o \cdot [C_{t},~h_{t-1},~x_t] + b_o) \end{eqnarray}$

'RNN > Study' 카테고리의 다른 글

(4) GRU  (0) 2020.01.17
(2) Basic of RNN - Vanilla RNN  (0) 2020.01.16
(1) Basic of RNN - Sequence Data  (0) 2020.01.16
(0) Overview of RNN  (0) 2020.01.14
Comments