RNN(Recurrent Neural Network)
-
RNN은 자연어 처리 및 시계열 데이터 등의 가변적인 길이의 sequence 데이터를 입력으로 받아 매 time step마다 output을 내는 모델이다.
- 위의 식과 같이 RNN은 이전 time step의 output인 $h_{t-1}$과 현재 input $x_t$를 input으로 받아 현재의 hidden state $h_t$를 출력한다.
- 이 때, 모든 time step에서는 동일한 Parameter $\theta$를 사용한다.
- RNN에서 Activation 함수로 주로 Tanh를 사용한다. Tanh는 zero-centered한 특성으로 인해 학습을 더 안정적으로 동작하는데 바람직하기 때문이다.
- 첫 번째 input이 들어왔을 때, 이전 시점의 hidden state는 존재하지 않기에 보통 0으로 채워진 벡터를 입력으로 준다.
-
Non Zero-centered의 문제점
- 대표적인 Non Zero-centered 함수인 Sigmoid를 예시로 들어보자.
- Non Zero-centered 특성을 가진 함수는 역전파 시에 문제가 발생한다.
- 역전파 과정에서 가중치의 변화율(Gradient)은 $\frac{dL}{dw} = \frac{dL}{df} \times \frac{df}{dw}$과 같이 표현할 수 있다.
- 위 그림에서 볼 수 있듯이, Sigmoid의 미분 값은 항상 양수이다. 즉, $\frac{df}{dw}$>0이므로 $\frac{dL}{df}$의 부호는 항상$\frac{dL}{df}$의 부호에 따라 결정된다.
- 또한, Sigmoid의 출력값이 항상 양수(0 ~ 1 사이)이므로, 역전파 과정에서 모든 가중치의 그래디언트가 같은 방향(부호)으로 변하는 경향이 있다. 이로 인해 가중치가 Zig-Zag 패턴으로 최적해를 찾게 되어 학습이 비효율적이다.
- 이런 RNN을 여러 Layer로 쌓아 사용하는 Multi-layer RNN, 이 후 시점의 정보들도 반영하는 Bidirectional RNN등으로 응용할 수 있다.
-
RNN의 형태
- one-to-one
- 입력과 출력이 모두 sequence가 아닌 단일한 time step의 데이터로 이루어진 형태(ex. CNN)
- one-to-many
- 단일 time step을 입력으로 여러 time step에 걸쳐 생성하는 형태(ex. Image captioning)
- many-to-one
- Sequence 형태의 입력을 받아 단일 time step의 출력을 생성하는 형태(ex. 감정 분류)
- many-to-many
- Sequence 형태의 입력을 받아 Sequence 형태의 출력을 생성하는 형태(ex. Language Model, Video 분류, 기계 번역)
- one-to-one
Language Model
-
Language Model은 many-to-many 형태로 하나의 input token이 들어왔을 때, delay없이 매 time step마다 하나 씩 output(다음 단어의 예측 값)을 출력하고, 이 output이 다음 time step의 input으로써 사용되는 Auto-Regressive Model이다.
-
Language Model의 추론 과정
- 각 문자를 One-hot encoding으로 입력
- 입력 $x_t$으로 hidden state $h_t$ 계산($h_t=tanh(W_{hh}h_{t-1} + W_{xh}x_t)$)
- hidden state $h_t$에서 출력 $y_t$를 계산($y_t = W_{hy}h_t$)
- Softmax를 통해 각 문자의 확률 계산(여기서 나온 문자가 다음 time step의 입력으로 사용)
- Loss 계산 및 학습
- Language Model의 학습은 각 time step에서의 출력(output)과 정답(GT) 사이의 Loss를 계산하고, 이를 기반으로 매 time step마다 파라미터($W_{xh}, W_{hy}, W_{hh}$)를 업데이트하는 방식으로 이루어진다.
- 이런 방식의 학습법을 Backpropagation Through Time (BPTT)라고 한다.
- 하지만 BPTT는 긴 sequence에 대해 계산 비용과 메모리 사용량이 매우 크다는 문제점이 있다. 이를 해결하기 위해 긴 sequence를 여러 개의 작은 단위(Chunk)로 나누어 학습하는 방식인Truncated BPTT가 고안되었다.
- Truncated BPTT에서는 하나의 Chunk가 Backward를 완료한 후, 은닉 상태($h_t$)만 메모리에 유지한 채 다음 Chunk로 넘겨 사용한다. 이를 통해 메모리 사용을 최적화할 수 있었다.