Let's go through the equations one by one.
Attention weight: $w = softmax(h_{t}^{dec}\cdot W_{a} \cdot h_{1:m}^{enc^{T}})$, $w \in \mathbb{R}^{bs \times 1 \times m}$
$h_{t}^{dec}$: the decoder hidden state at timestep $t$, $h_{t}^{dec} \in \mathbb{R}^{bs \times 1 \times hs}$
$W_{a}$: a linear layer that projects the decoder state into a query, $W_{a} \in \mathbb{R}^{hs \times hs}$
$h_{1:m}^{enc^{T}}$: the encoder hidden states from all timesteps, transposed, $h_{1:m}^{enc^{T}} \in \mathbb{R}^{bs \times hs \times m}$
In other words, this computes the dot product (similarity) between a single decoder output and every encoder output.
<aside> 💡 Since $a \cdot b = |a||b|\cos \theta$, the dot product grows as the two vectors point in more similar directions ⇒ the higher the similarity, the larger the dot product.
</aside>
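A quick numeric sanity check of this intuition (a standalone sketch, not part of the original code):

```python
import torch

query = torch.tensor([1.0, 1.0])        # reference direction
similar = torch.tensor([0.9, 1.1])      # points roughly the same way
dissimilar = torch.tensor([1.0, -1.0])  # nearly orthogonal

# Vectors pointing in similar directions yield a larger dot product.
print(torch.dot(query, similar))     # tensor(2.)
print(torch.dot(query, dissimilar))  # tensor(0.)
```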
```python
query = self.linear(h_t_tgt)
# |query| = (batch_size, 1, hidden_size)
weight = torch.bmm(query, h_src.transpose(1, 2))
# |weight| = (batch_size, 1, length)
weight = self.softmax(weight)
```
Context vector: $c = w \cdot h_{1:m}^{enc}$, $c \in \mathbb{R}^{bs \times 1 \times hs}$
$w$: attention weight, $w \in \mathbb{R}^{bs \times 1 \times m}$
$h_{1:m}^{enc}$: the encoder hidden states from all timesteps
In other words, a weighted sum of all the encoder timestep outputs, weighted by the attention weights.
```python
context_vector = torch.bmm(weight, h_src)
# |context_vector| = (batch_size, 1, hidden_size)
```
Combining the decoder output with the context vector: $\tilde{h}_{t}^{dec} = tanh([h_{t}^{dec};c]\cdot W_{concat})$, $\tilde{h}_{t}^{dec} \in \mathbb{R}^{bs \times 1 \times hs}$
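This step lives in the decoder, not in the Attention module below. A minimal sketch, assuming a hypothetical `linear_concat` layer (the name and the hidden size are illustrative, not from the original notes):

```python
import torch
import torch.nn as nn

hidden_size = 256  # assumed value for illustration
linear_concat = nn.Linear(hidden_size * 2, hidden_size)

h_t_dec = torch.randn(4, 1, hidden_size)         # decoder hidden state at step t
context_vector = torch.randn(4, 1, hidden_size)  # output of the attention module

# [h_t^dec ; c]: concatenate along the feature dimension, project back, apply tanh.
h_tilde = torch.tanh(linear_concat(torch.cat([h_t_dec, context_vector], dim=-1)))
# |h_tilde| = (batch_size, 1, hidden_size)
```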
Final output: $\hat{y}_{t} = softmax(\tilde{h}_{t}^{dec} \cdot W_{gen})$, $\hat{y}_{t} \in \mathbb{R}^{bs \times 1 \times |V|}$
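Likewise, the generator is not shown in the code below. A sketch, assuming a hypothetical `linear_gen` layer and an illustrative vocabulary size:

```python
import torch
import torch.nn as nn

hidden_size, vocab_size = 256, 10000  # assumed values for illustration
linear_gen = nn.Linear(hidden_size, vocab_size)
softmax = nn.Softmax(dim=-1)

h_tilde = torch.randn(4, 1, hidden_size)
y_hat = softmax(linear_gen(h_tilde))
# |y_hat| = (batch_size, 1, vocab_size); each distribution sums to 1.
```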
```python
import torch
import torch.nn as nn


class Attention(nn.Module):

    def __init__(self, hidden_size):
        super(Attention, self).__init__()

        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, h_src, h_t_tgt, mask=None):
        # |h_src| = (batch_size, length, hidden_size)
        # |h_t_tgt| = (batch_size, 1, hidden_size)
        # |mask| = (batch_size, length)

        query = self.linear(h_t_tgt)
        # |query| = (batch_size, 1, hidden_size)

        weight = torch.bmm(query, h_src.transpose(1, 2))
        # |weight| = (batch_size, 1, length)
        if mask is not None:
            # Set each weight to -inf where the mask value equals 1.
            # Since softmax maps -inf to 0,
            # masked weights become 0 after the softmax operation.
            # Thus, if a sample is shorter than the others in the mini-batch,
            # the weights for its empty timesteps are set to 0.
            weight.masked_fill_(mask.unsqueeze(1), -float('inf'))
        weight = self.softmax(weight)

        context_vector = torch.bmm(weight, h_src)
        # |context_vector| = (batch_size, 1, hidden_size)

        return context_vector
```
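A quick usage sketch of the class above, with assumed shapes and a hand-built padding mask for illustration:

```python
batch_size, length, hidden_size = 2, 5, 8
h_src = torch.randn(batch_size, length, hidden_size)  # encoder outputs
h_t_tgt = torch.randn(batch_size, 1, hidden_size)     # one decoder timestep

# Mask is True at padded positions; here the second sample has length 3.
mask = torch.tensor([[False, False, False, False, False],
                     [False, False, False, True,  True]])

attn = Attention(hidden_size)
context = attn(h_src, h_t_tgt, mask=mask)
print(context.shape)  # torch.Size([2, 1, 8])
```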
