Created
July 19, 2018 01:01
-
-
Save lixinsu/31c6b756a0a4f636ba79aa2c11bbabf3 to your computer and use it in GitHub Desktop.
Revisions
-
lixinsu created this gist
Jul 19, 2018 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,189 @@ class MultiheadAttention(nn.Module): """Multi-headed attention. See "Attention Is All You Need" for more details. """ def __init__(self, embed_dim, num_heads, dropout=0., bias=True): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" self.scaling = self.head_dim**-0.5 self._mask = None self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim)) if bias: self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim)) else: self.register_parameter('in_proj_bias', None) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.reset_parameters() def reset_parameters(self): nn.init.xavier_uniform_(self.in_proj_weight) nn.init.xavier_uniform_(self.out_proj.weight) if self.in_proj_bias is not None: nn.init.constant_(self.in_proj_bias, 0.) nn.init.constant_(self.out_proj.bias, 0.) def forward(self, query, key, value, mask_future_timesteps=False, key_padding_mask=None, incremental_state=None, need_weights=True, static_kv=False): """Input shape: Time x Batch x Channel Self-attention can be implemented by passing in the same arguments for query, key and value. Future timesteps can be masked with the `mask_future_timesteps` argument. Padding elements can be excluded from the key by passing a binary ByteTensor (`key_padding_mask`) with shape: batch x src_len, where padding elements are indicated by 1s. """ qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() kv_same = key.data_ptr() == value.data_ptr() tgt_len, bsz, embed_dim = query.size() assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] assert key.size() == value.size() if incremental_state is not None: saved_state = self._get_input_buffer(incremental_state) if 'prev_key' in saved_state: # previous time steps are cached - no need to recompute # key and value if they are static if static_kv: assert kv_same and not qkv_same key = value = None else: saved_state = None if qkv_same: # self-attention q, k, v = self.in_proj_qkv(query) elif kv_same: # encoder-decoder attention q = self.in_proj_q(query) if key is None: assert value is None # this will allow us to concat it with previous value and get # just get the previous value k = v = q.new(0) else: k, v = self.in_proj_kv(key) else: q = self.in_proj_q(query) k = self.in_proj_k(key) v = self.in_proj_v(value) q *= self.scaling if saved_state is not None: if 'prev_key' in saved_state: k = torch.cat((saved_state['prev_key'], k), dim=0) if 'prev_value' in saved_state: v = torch.cat((saved_state['prev_value'], v), dim=0) saved_state['prev_key'] = k saved_state['prev_value'] = v self._set_input_buffer(incremental_state, saved_state) src_len = k.size(0) if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz assert key_padding_mask.size(1) == src_len q = q.contiguous().view(tgt_len, bsz*self.num_heads, self.head_dim).transpose(0, 1) k = k.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1) v = v.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1) attn_weights = torch.bmm(q, k.transpose(1, 2)) assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] # only apply masking at training time (when incremental state is None) if mask_future_timesteps and incremental_state is None: assert query.size() == key.size(), \ 'mask_future_timesteps only applies to self-attention' attn_weights += self.buffered_mask(attn_weights).unsqueeze(0) if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.float().masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'), ).type_as(attn_weights) # FP16 support: cast to float and back attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights) attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) attn = torch.bmm(attn_weights, v) assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn = self.out_proj(attn) if need_weights: # average attention weights over heads attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.sum(dim=1) / self.num_heads else: attn_weights = None return attn, attn_weights def in_proj_qkv(self, query): return self._in_proj(query).chunk(3, dim=-1) def in_proj_kv(self, key): return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1) def in_proj_q(self, query): return self._in_proj(query, end=self.embed_dim) def in_proj_k(self, key): return self._in_proj(key, start=self.embed_dim, end=2*self.embed_dim) def in_proj_v(self, value): return self._in_proj(value, start=2*self.embed_dim) def _in_proj(self, input, start=None, end=None): weight = self.in_proj_weight bias = self.in_proj_bias if end is not None: weight = weight[:end, :] if bias is not None: bias = bias[:end] if start is not None: weight = weight[start:, :] if bias is not None: bias = bias[start:] return F.linear(input, weight, bias) def buffered_mask(self, tensor): dim = tensor.size(-1) if self._mask is None: self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) if self._mask.size(0) < dim: self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1) return self._mask[:dim, :dim] def reorder_incremental_state(self, incremental_state, new_order): """Reorder buffered internal state (for incremental generation).""" input_buffer = self._get_input_buffer(incremental_state) if input_buffer is not None: for k in input_buffer.keys(): input_buffer[k] = input_buffer[k].index_select(1, new_order) self._set_input_buffer(incremental_state, input_buffer) def _get_input_buffer(self, incremental_state): return utils.get_incremental_state( self, incremental_state, 'attn_state', ) or {} def _set_input_buffer(self, incremental_state, buffer): utils.set_incremental_state( self, incremental_state, 'attn_state', buffer, )