대략적인 모델 파이프라인
1. Introduction
이미지에 대한 벡터 표현을 얻기 위해 이전에는 convnet의 최상위 층의 representation을 사용하였다. 여기에는 캡션에 유용할 수 있는 정보가 손실될 수 있는 문제가 생긴다. 그래서 보다 낮은 수준의 representation을 사용하면 이 정보를 보존하는데 도움이 된다. 그러나 이러한 기능을 사용하려면 중요한 정보로 모델을 control 하는 메커니즘이 필요하다.
이 논문에서는 공통 프레임워크를 (모델 구조 동일)사용하는 2가지 접근 방식을 제안한다.
- “hard” attention mechanism : variational lower bound을 최대화하여 확률론적 훈련
- “soft” attention mechanism. : 표준 역전파 방법으로 훈련
또한 attention의 장점 중 하나가 모델이 "보는" 것을 시각화하는 능력이라는 점을 보여준다.
2. Background
Multinuolli Distribution
베르누이 분포의 확장판이다. 베르누이 분포는 0 이나 1(또는 -1 이나 1)이 나오는 확률 변수의 분포였다.
Multinuolli 분포는 1부터 K까지의 K개의 정수 값 중 하나가 나오는 확률 분포이다.
Multinuolli 분포의 모수 θ는 (θ1, ..., θK) 의 형태의 벡터이고 각각의 값은 0과 1 사이이면 모든 요소의 합이 1이다. 즉 소프트맥스 함수를 통과한 것과 같다. 이러한 확률을 기반으로 출력이 결정되는 것 같다.
확률 기반의 hard attention 메커니즘에서 사용된다.
3. Image Caption Generation with Attention Mechanism
Model Details
앞서 설명한 2가지 접근 방식을 소개하고 주요 차이점은 Φ function이다.
Encoder : Convolutional Features
모델은 하나의 단일 원본 이미지를 입력 받고 캡션 (1, K) 인코딩된 시퀀스 y을 출력한다.
K : vocab size
C : length of cation
annotation vector라고 부르는 feature vector 세트를 추출하기 위해 convnet을 사용한다. 이는 L개의 벡터를 생성하며, 각 벡터는 이미지의 일부에 해당하는 D차원의 representation이다.
feature vector와 2D 이미지 부분 간의 대응성을 얻기 위해 하위 conv layer에서 feature을 추출한다. 이를 통해 디코더는 모든 feature vector의 하위 집합을 선택하여 이미지의 특정 부분에 선택적으로 초점을 맞출 수 있다.
Decoder : LSTM
매 time step마다 이전의 hidden state 및 이전에 생성된 단어를 조건으로 하나의 단어를 생성하는 LSTM을 사용한다.
context vector $\hat{z}_t$는 시간 t에 입력된 이미지의 표현이다.
- 서로 다른 이미지 위치에서 추출된 feature vector ($a_i$)로 부터 $\hat{z}_t$를 계산한다.
- 각 위치 i에 대해 다음 단어를 생성하기 위해 올바른 위치에 초점을 맞추었는지에 대한 확률 또는 상대적 중요석으로 해석될 수 있는 가중치 $α_i$를 생성한다.
- 가중치 $α_i$는 이전 hidden state에 대해 조건화된 MLP를 사용하는 attention 함수에 의해 계산된다.
- $f_{att}$가 attention 함수이고 이를 통해 계산된 $e_{ti}$에 소프트맥스 함수를 사용하여 가중치 αi를 계산한다.
- context vector $\hat{z}_t$는 다음과 같이 계산된다.
이부분이 수식이 복잡한데 우리가 알고있는 attention의 원리를 생각해보자
Attention(Q, K, V) = Attention value
Q = Query : t 시점의 디코더 셀에서의 은닉 상태
K = Keys : 모든 시점의 인코더 셀의 은닉 상태들
V = Values : 모든 시점의 인코더 셀의 은닉 상태들
여기서 Q가 $h_{t-1}$에 해당하고 $a_i$가 K에 해당한다. Q와 K를 dot product하여 어텐션 스코어를 구하고 이러한 어텐션 스코어의 모음을 $e_{ti}$라고 한다. 그런다음 이를 소프트맥스 함수를 적용하여 attention distribution $α_i$을 구한다.
이렇게 구해진 $α_i$을 $a_i$와 가중합 하여 attention value을 구한다. (논문에서는 context vector을 구하는 식 (6)에 해당)
LSTM의 초기 memory state (c0)와 hidden state (h0)은 feature vector의 평균을 입력으로 하는 각각의 분리된 MLP로 계산한다.
그런 다음 출력 단어 확률을 계산하기 위해 deep RNN(Pascanu et al., How to construct deep recurrent neural networks, In ICLR, 2014)을 사용한다.
이전 타임스텝까지 생성한 단어들 $y_{t-1}$, hidden state $h_t$, context vector $\hat{z}_t$을 입력으로 받아서 확률을 계산한다.
L과 E는 학습되는 가중치이다.
4. Learning Stochastic "Hard" vs Deterministic "Soft" Attention
attention model $f_{att}$을 소개한다.
이 논문이 2015년에 나왔고 그 당시 GPU 등 컴퓨팅 자원의 한계가 있어서 이러한 것을 극복하고자 저자들이 제안한 learning algorithm이라고 생각하면 된다. 현재 시점에서 코드로 구현할 때는 딱히 신경쓰지 않기 때문에 그냥 짚고만 넘어간다.
Stochastic Hard Attention
$s_t$ : t번째 단어를 생성할 때 모델이 focus할 부분을 결정하는 위치 변수
$s_{t,i}$ : L 중 i번째 위치가 시각적 특징을 추출하는 데 사용되는 위치인 경우 1로 설정되는 지표 (one-hot)
8번 수식은 잘 이해하지 못했는데 어쨋든 앞에서 attention을 통해 context vector을 구하는 방법을 생각해보면 $α_{t,i}$로 부터 $s_{t,i}$(one-hot으로 변경된 가중치)을 구하고 이를 feature vector인 $a_i$와 가중합을 하여 $\hat{z}_t$을 구하는 것이다.
$L_s$는 feature vector a가 주어졌을 때 문장 y를 출력해내는 log-likelihood이다. 이것을 maximizing 하는것이 논문에서 정의한 문제이기도 하다.
가중치 행렬 W에 대한 $L_s$의 differential을 나타내는 수식이다.
식 11과 바뀐 부분은 s를 multinoulli 분포를 통해 랜덤 샘플링하여 differential을 근사한다는 것만 차이가 있다. 왜 이렇게 하냐면 그 당시 컴퓨팅 자원의 한계로 일일히 다 계산하면 시간이 엄청 오래걸리기 때문이다. 이 알고리즘 앞에 확률적이라는 수식어가 붙는 이유도 여기서 랜덤 샘플링을 하기 때문이다.
랜덤 샘플링을 하게되면 분산이 커지게 되는데 이를 줄이기 위해서 저자들은 이동 평균을 사용하였다.
이를 적용한 최종 수식이 다음과 같다
파라미터 λ은 cross-validation으로 설정되는 하이퍼파라미터이다.
Deterministic "Soft" Attention
context vector을 구할 때 앞에서 가중치 벡터에서 랜덤 샘플링을 한 다음 가중합을 계산하는 것이 아니고 우리가 원래 알고 있는 Bahdanau 방식으로 계산하는 것이다. 즉 랜덤 샘플링 부분만 제외하고 나머지는 동일하다고 볼 수 있다.
5. Training Procedure
- Flickr8k, Flickr30k, MS COCO dataset 사용
- feature vector을 생성하기 위해 ImageNet에 pretrained된 VGGnet을 finetuning 없이 사용
Experiments
5 reference senteces per image & fix vocab size of 10000
hard
soft
qualitative analysis을 보면 문장을 생성할 때 모델이 어디에 집중했는지 확인할 수 있다.
아무래도 2015년에 발행된 논문이다보니 디코더로 LSTM을 사용했기 때문에 트랜스포머 기반의 디코더를 사용하면 성능이 더 좋지 않을까란 생각이 든다. 아마도 그런 논문이 있을 것 같은데 코드 구현을 마친 후에 한 번 찾아봐야겠다.