Mathematics for Artificial Intelligence 10강: RNN 첫걸음

2023. 1. 6. 22:58BOOTCAMP/boostcamp AI Tech Pre-Course

시퀀스 데이터 이해하기

  • 소리, 문자열, 주가 등의 데이터를 시퀀스(sequence) 데이터로 분류한다.
  • 시계열(time-series) 데이터는 시간 순서에 따라 나열된 데이터로 시퀀스 데이터에 속한다.
  • 시퀀스 데이터는 독립동등분포(i.i.d) 가정을 잘 위배하기 때문에 순서를 바꾸거나 과거 정보에 손실이 발생하면 데이터의 확률분포도 바뀌게 된다.
    • 과거 정보 또는 앞뒤 맥락 없이 미래를 예측하거나 문장을 완성하는 건 불가능하다.

시퀀스 데이터를 어떻게 다루는가?

  • 이전 시퀀스의 정보를 가지고 앞으로 발생할 데이터의 확률분포를 다루기 위해 조건부확률을 이용할 수 있다.

P(X1,..., Xt) = P(Xt|X1,..., Xt-1) P(X1,..., Xt-1)

  • 이전 시퀀스의 정보를 가지고 앞으로 발생할 데이터의 확률분포를 다루기 위해 조건부확률을 이용할 수 있다. Xt ~ P(Xt|Xt-1,..., X1)
    • 위 조건부확률은 과거의 모든 정보를 사용하지만 시퀀스 데이터를 분석할 때 모든 과거 정보들이 필요한 것은 아니다.
  • 시퀀스 데이터를 다루기 위해선 길이가 가변적인 데이터를 다룰 수 있는 모델이 필요하다. Xt ~ P(Xt|Xt-1,..., X1) Xt+1 ~ P(Xt+1|Xt, Xt-1,... X1)
    • 고정된 길이 델타만큼의 시퀀스만 사용하는 경우 AR(델타)(Autoregressive Model) 자기 회귀모델이라고 부른다.
    • 또 다른 방법은 바로 이전 정보를 제외한 나머지 정보들을 Ht라는 잠재변수로 인코딩해서 활용하는 잠재 AR 모델이다.
    • 잠재변수 Ht를 신경망을 통해 반복해서 사용하여 시퀀스 데이터의 패턴을 학습하는 모델이 RNN이다.

Recurrent Neural Network을 이해하기

  • 가장 기본적인 RNN 모형은 MLP와 유사한 모양이다.
    • W(1), W(2)은 시퀀스와 상관없이 불변인 행렬이다. H = 델타(XW(1) + b(1)) -> 잠재변수 = 활성화함수(가중치행렬 + bias)
    • O = HW(2) + b(2)
  • RNN은 이전 순서의 잠재변수와 현재의 입력을 활용하여 모델링한다. Ht = 델타(XtWx(1) + Ht-1WH(1) + b(1))
    • 과거의 정보를 H에 넣을 수 없다.
    • 잠재변수인 Ht를 복제해서 다음 순서의 잠재변수를 인코딩하는 데 사용한다.
  • Ot = HtW(2) + b(2)
  • RNN의 역전 파는 잠재변수의 연결그래프에 따라 순차적으로 계산합니다.
    • 이를 Backpropagation Through Time(BPTT)이라 하며 RNN의 역전파 방법이다.

BPTT

  • BPTT를 통해 RNN의 가중치행렬의 미분을 계산해보면 아래와 같이 미분의 곱으로 이루어진 항이 계산된다.

ht = f(xt, ht-1, wh) and Ot = g(ht, wo).

기울기 소실의 해결책

  • 시퀀스 길이가 길어지는 경우 BPTT를 통한 역전파 알고리즘의 계산이 불안정해지므로 길이를 끊는 것이 필요하다.
    • 이를 truncated BPTT라 부른다.
  • 과거의 정보를 잃어버리기 쉽기에 몇 개의 그래디언트를 끊고, 몇개의 블록을 나눠서 백 프로 파게이션에 계산하는 것이다. 모든 t시점이 아닌 나눠서 전달하면서 그래디언트 배니싱을 해결할 수 있게 된다.
  • 이런 문제들 때문에 Vanilla RNN은 길이가 긴 시퀀스를 처리하는데 문제가 있다.
    • 이를 해결하기 위해 등장한 RNN 네트워크가 LSTM과 GRU이다.

잠재변수를 활용해서 자기 회귀모습에 대해 배웠다. 실제 시퀀스 모델을 다룰 땐, LSTM, GRU과 같은 고급된 것을 사용하는 이유는 BPTT 할 때 그래디언트가 0으로 소실되는 문제가 있기 때문이다.