@lixinsu
Created July 19, 2018 01:01
multihead_attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

# fairseq's utils module provides fill_with_neg_inf and the
# incremental-state helpers used further down in this file.
from fairseq import utils


class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details. A minimal usage sketch
    follows the class definition at the bottom of this file.
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5
        self._mask = None

        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)

    def forward(self, query, key, value, mask_future_timesteps=False,
                key_padding_mask=None, incremental_state=None,
                need_weights=True, static_kv=False):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """

        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert kv_same and not qkv_same
                    key = value = None
        else:
            saved_state = None

        if qkv_same:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                # this will allow us to concat it with the previous value and
                # get back just the previous value
                k = v = q.new(0)
            else:
                k, v = self.in_proj_kv(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if saved_state is not None:
            if 'prev_key' in saved_state:
                k = torch.cat((saved_state['prev_key'], k), dim=0)
            if 'prev_value' in saved_state:
                v = torch.cat((saved_state['prev_value'], v), dim=0)
            saved_state['prev_key'] = k
            saved_state['prev_value'] = v
            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(0)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        # only apply masking at training time (when incremental state is None)
        if mask_future_timesteps and incremental_state is None:
            assert query.size() == key.size(), \
                'mask_future_timesteps only applies to self-attention'
            attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.float().masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            ).type_as(attn_weights)  # FP16 support: cast to float and back
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_weights, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            # average attention weights over heads
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.sum(dim=1) / self.num_heads
        else:
            attn_weights = None

        return attn, attn_weights

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=None, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        if end is not None:
            weight = weight[:end, :]
            if bias is not None:
                bias = bias[:end]
        if start is not None:
            weight = weight[start:, :]
            if bias is not None:
                bias = bias[start:]
        return F.linear(input, weight, bias)

    def buffered_mask(self, tensor):
        dim = tensor.size(-1)
        if self._mask is None:
            self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if self._mask.size(0) < dim:
            self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
        return self._mask[:dim, :dim]

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer[k] = input_buffer[k].index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        utils.set_incremental_state(
            self,
            incremental_state,
            'attn_state',
            buffer,
        )
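

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist). It assumes the imports
# above resolve and only exercises plain self-attention plus a key padding
# mask, so neither the future-timestep mask nor the fairseq incremental-state
# helpers are actually invoked. Dimensions are arbitrary example values.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)
    tgt_len, bsz, embed_dim, num_heads = 5, 2, 16, 4

    attn_layer = MultiheadAttention(embed_dim, num_heads, dropout=0.1)
    x = torch.rand(tgt_len, bsz, embed_dim)  # Time x Batch x Channel

    # Self-attention: pass the same tensor as query, key and value.
    out, weights = attn_layer(x, x, x)
    print(out.shape)      # torch.Size([5, 2, 16])
    print(weights.shape)  # torch.Size([2, 5, 5]), averaged over heads

    # Excluding padded key positions: 1/True marks padding. Recent PyTorch
    # expects a bool mask; the 2018-era code used a ByteTensor instead.
    key_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)
    key_padding_mask[:, -1] = True  # treat the last source timestep as padding
    out, weights = attn_layer(x, x, x, key_padding_mask=key_padding_mask)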