import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F

# This module relies on fairseq's utility helpers (fill_with_neg_inf,
# get_incremental_state, set_incremental_state) exposed via `utils`.
from fairseq import utils


class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5
        self._mask = None

        # queries, keys and values share a single packed input projection
        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
    def forward(self, query, key, value, mask_future_timesteps=False,
                key_padding_mask=None, incremental_state=None,
                need_weights=True, static_kv=False):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert kv_same and not qkv_same
                    key = value = None
        else:
            saved_state = None

        if qkv_same:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                # this will allow us to concat it with previous value and get
                # just the previous value
                k = v = q.new(0)
            else:
                k, v = self.in_proj_kv(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if saved_state is not None:
            if 'prev_key' in saved_state:
                k = torch.cat((saved_state['prev_key'], k), dim=0)
            if 'prev_value' in saved_state:
                v = torch.cat((saved_state['prev_value'], v), dim=0)
            saved_state['prev_key'] = k
            saved_state['prev_value'] = v
            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(0)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        # reshape to (bsz * num_heads, seq_len, head_dim) so each head acts as
        # an independent batch element in the batched matmuls below
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        # only apply masking at training time (when incremental state is None)
        if mask_future_timesteps and incremental_state is None:
            assert query.size() == key.size(), \
                'mask_future_timesteps only applies to self-attention'
            attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.float().masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            ).type_as(attn_weights)  # FP16 support: cast to float and back
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_weights, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            # average attention weights over heads
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.sum(dim=1) / self.num_heads
        else:
            attn_weights = None

        return attn, attn_weights
    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=None, end=None):
        # q/k/v share one (3 * embed_dim, embed_dim) weight; slicing rows
        # [start:end] selects the projection for the requested block
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        if end is not None:
            weight = weight[:end, :]
            if bias is not None:
                bias = bias[:end]
        if start is not None:
            weight = weight[start:, :]
            if bias is not None:
                bias = bias[start:]
        return F.linear(input, weight, bias)

    def buffered_mask(self, tensor):
        dim = tensor.size(-1)
        if self._mask is None:
            self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if self._mask.size(0) < dim:
            self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
        return self._mask[:dim, :dim]
    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer[k] = input_buffer[k].index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        utils.set_incremental_state(
            self,
            incremental_state,
            'attn_state',
            buffer,
        )
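

# --- Usage sketch (added; not part of the original gist) ---------------------
# A minimal self-attention smoke test illustrating the Time x Batch x Channel
# input layout and the `mask_future_timesteps` flag described in forward()'s
# docstring. The sizes below are arbitrary illustration values.
if __name__ == '__main__':
    embed_dim, num_heads, tgt_len, bsz = 16, 4, 5, 2
    attn_layer = MultiheadAttention(embed_dim, num_heads, dropout=0.1)
    x = torch.rand(tgt_len, bsz, embed_dim)  # Time x Batch x Channel
    # Self-attention: pass the same tensor as query, key and value.
    out, weights = attn_layer(x, x, x, mask_future_timesteps=True)
    assert out.size() == (tgt_len, bsz, embed_dim)
    # Returned weights are averaged over heads: batch x tgt_len x src_len.
    assert weights.size() == (bsz, tgt_len, tgt_len)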
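

# --- Another usage sketch (added; not part of the original gist) -------------
# Encoder-decoder attention with a key padding mask, illustrating the
# `key_padding_mask` semantics from forward()'s docstring: positions marked
# with 1/True are padding keys and receive zero attention weight. A bool mask
# is used here (the docstring says ByteTensor; bool is what recent PyTorch's
# masked_fill expects), and dropout=0 keeps the check deterministic.
if __name__ == '__main__':
    embed_dim, num_heads, tgt_len, src_len, bsz = 16, 4, 3, 6, 2
    attn_layer = MultiheadAttention(embed_dim, num_heads, dropout=0.)
    query = torch.rand(tgt_len, bsz, embed_dim)
    memory = torch.rand(src_len, bsz, embed_dim)  # e.g. encoder output
    pad_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
    pad_mask[0, 4:] = True  # last two source positions of sequence 0 are padding
    out, weights = attn_layer(query, memory, memory, key_padding_mask=pad_mask)
    # Padded keys get -inf scores, hence exactly zero weight after softmax.
    assert torch.all(weights[0, :, 4:] == 0)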