    """
    # https://github.com/bentrevett/pytorch-seq2seq/issues/129
    # https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb
    """
    import torch
    import torch.nn as nn
    import torch.optim as optim

    import torchtext
    from torchtext.datasets import Multi30k
    from torchtext.data import Field, BucketIterator

    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker

    import spacy
    import numpy as np

    import random
    import math
    import time
    from IPython.core.debugger import set_trace #set_trace()


######## ENCODER PART #################################

class Encoder(nn.Module):
    def __init__(self,
                 input_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 device,
                 max_length = 100):
        super().__init__()

        self.device = device

        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([EncoderLayer(hid_dim,
                                                  n_heads,
                                                  pf_dim,
                                                  dropout,
                                                  device)
                                     for _ in range(n_layers)])

        self.dropout = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, src, src_mask):

        #src = [batch size, src len]
        #src_mask = [batch size, 1, 1, src len]

        batch_size = src.shape[0]
        src_len = src.shape[1]

        pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)

        #pos = [batch size, src len]

        src = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))

        #src = [batch size, src len, hid dim]

        for layer in self.layers:
            src = layer(src, src_mask)

        #src = [batch size, src len, hid dim]

        return src
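
# A quick shape walk-through of the encoder (toy sizes for illustration):
# with hid_dim = 256, a src batch of [2, 10] token ids is embedded to
# [2, 10, 256], scaled by sqrt(256) (the sqrt(d_model) factor from the
# paper), summed with positional embeddings, then passed through each
# EncoderLayer, which preserves the [batch size, src len, hid dim] shape.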

class EncoderLayer(nn.Module):
    def __init__(self,
                 hid_dim,
                 n_heads,
                 pf_dim,
                 dropout,
                 device):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim,
                                                                     pf_dim,
                                                                     dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        #src = [batch size, src len, hid dim]
        #src_mask = [batch size, 1, 1, src len]

        #self attention
        _src, _ = self.self_attention(src, src, src, src_mask)

        #dropout, residual connection and layer norm
        src = self.self_attn_layer_norm(src + self.dropout(_src))

        #src = [batch size, src len, hid dim]

        #positionwise feedforward
        _src = self.positionwise_feedforward(src)

        #dropout, residual and layer norm
        src = self.ff_layer_norm(src + self.dropout(_src))

        #src = [batch size, src len, hid dim]

        return src
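
# Note the post-norm ordering used here (as in the original paper):
# LayerNorm is applied *after* the residual addition, i.e.
# x = LayerNorm(x + Dropout(Sublayer(x))), for both the self-attention
# and the position-wise feedforward sublayers.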

#### ATTENTION LAYER #################################################################

class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()

        assert hid_dim % n_heads == 0

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        self.fc_o = nn.Linear(hid_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask = None):

        batch_size = query.shape[0]

        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]

        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)

        #Q = [batch size, query len, hid dim]
        #K = [batch size, key len, hid dim]
        #V = [batch size, value len, hid dim]

        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        #Q = [batch size, n heads, query len, head dim]
        #K = [batch size, n heads, key len, head dim]
        #V = [batch size, n heads, value len, head dim]

        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        #energy = [batch size, n heads, query len, key len]

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = torch.softmax(energy, dim = -1)

        #attention = [batch size, n heads, query len, key len]

        x = torch.matmul(self.dropout(attention), V)

        #x = [batch size, n heads, query len, head dim]

        x = x.permute(0, 2, 1, 3).contiguous()

        #x = [batch size, query len, n heads, head dim]

        x = x.view(batch_size, -1, self.hid_dim)

        #x = [batch size, query len, hid dim]

        x = self.fc_o(x)

        #x = [batch size, query len, hid dim]

        return x, attention
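
# Head-splitting worked example (sizes from the demo below): with
# hid_dim = 256 and n_heads = 8, head_dim = 256 // 8 = 32. A query of
# [2, 9, 256] becomes [2, 9, 8, 32] after .view() and [2, 8, 9, 32]
# after .permute(), so each of the 8 heads attends over its own 32-dim
# projections; energy is then [2, 8, 9, key len], scaled by sqrt(32).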


class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()

        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        #x = [batch size, seq len, hid dim]

        x = self.dropout(torch.relu(self.fc_1(x)))

        #x = [batch size, seq len, pf dim]

        x = self.fc_2(x)

        #x = [batch size, seq len, hid dim]

        return x
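
# The feedforward sublayer expands then contracts: with the demo sizes
# below, 256 -> 512 -> 256, applied independently at every position
# (hence "position-wise"); the ReLU between the two linear layers is
# what makes it nonlinear.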

###### decoder part ###############

class Decoder(nn.Module):
    def __init__(self,
                 output_dim,
                 hid_dim,
                 n_layers,
                 n_heads,
                 pf_dim,
                 dropout,
                 device,
                 max_length = 100):
        super().__init__()

        self.device = device

        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([DecoderLayer(hid_dim,
                                                  n_heads,
                                                  pf_dim,
                                                  dropout,
                                                  device)
                                     for _ in range(n_layers)])

        self.fc_out = nn.Linear(hid_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, trg, enc_src, trg_mask, src_mask):

        #trg = [batch size, trg len]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, 1, trg len, trg len]
        #src_mask = [batch size, 1, 1, src len]

        batch_size = trg.shape[0]
        trg_len = trg.shape[1]

        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)

        #pos = [batch size, trg len]

        trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))

        #trg = [batch size, trg len, hid dim]

        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)

        #trg = [batch size, trg len, hid dim]
        #attention = [batch size, n heads, trg len, src len]

        output = self.fc_out(trg)

        #output = [batch size, trg len, output dim]

        return output, attention
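
# Note: only the cross-attention weights of the *last* decoder layer are
# returned, since the loop overwrites `attention` on each iteration; that
# is the tensor used when visualizing attention maps.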


############# decoder Layer #################################

class DecoderLayer(nn.Module):
    def __init__(self,
                 hid_dim,
                 n_heads,
                 pf_dim,
                 dropout,
                 device):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim,
                                                                     pf_dim,
                                                                     dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg, enc_src, trg_mask, src_mask):

        #trg = [batch size, trg len, hid dim]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, 1, trg len, trg len]
        #src_mask = [batch size, 1, 1, src len]

        #self attention
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)

        #dropout, residual connection and layer norm
        trg = self.self_attn_layer_norm(trg + self.dropout(_trg))

        #trg = [batch size, trg len, hid dim]

        #encoder attention
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)

        #dropout, residual connection and layer norm
        trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))

        #trg = [batch size, trg len, hid dim]

        #positionwise feedforward
        _trg = self.positionwise_feedforward(trg)

        #dropout, residual and layer norm
        trg = self.ff_layer_norm(trg + self.dropout(_trg))

        #trg = [batch size, trg len, hid dim]
        #attention = [batch size, n heads, trg len, src len]

        return trg, attention
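
# In the cross-attention ("encoder attention") call above, the decoder
# states are the queries while the encoder output supplies both keys and
# values: self.encoder_attention(trg, enc_src, enc_src, src_mask). The
# src_mask only hides <pad> positions, since the decoder may attend to
# every real source token.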

class Seq2Seq(nn.Module):
    def __init__(self,
                 encoder,
                 decoder,
                 src_pad_idx,
                 trg_pad_idx,
                 device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        #src = [batch size, src len]

        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

        #src_mask = [batch size, 1, 1, src len]

        return src_mask
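
    # The two unsqueezed singleton dims let src_mask broadcast against the
    # energy tensor [batch size, n heads, query len, key len] inside
    # masked_fill, so one mask serves every head and every query position.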

    def make_trg_mask(self, trg):

        #trg = [batch size, trg len]

        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)

        #trg_pad_mask = [batch size, 1, 1, trg len]

        trg_len = trg.shape[1]

        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()

        #trg_sub_mask = [trg len, trg len]

        trg_mask = trg_pad_mask & trg_sub_mask

        #trg_mask = [batch size, 1, trg len, trg len]

        return trg_mask
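
    # Worked example for trg_len = 4: torch.tril gives the "subsequent" mask
    #   [[1, 0, 0, 0],
    #    [1, 1, 0, 0],
    #    [1, 1, 1, 0],
    #    [1, 1, 1, 1]]
    # so position i may only attend to positions <= i; AND-ing with the pad
    # mask additionally zeroes out columns holding <pad> tokens.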

    def forward(self, src, trg):

        #src = [batch size, src len]
        #trg = [batch size, trg len]

        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        #src_mask = [batch size, 1, 1, src len]
        #trg_mask = [batch size, 1, trg len, trg len]

        enc_src = self.encoder(src, src_mask)

        #enc_src = [batch size, src len, hid dim]

        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)

        #output = [batch size, trg len, output dim]
        #attention = [batch size, n heads, trg len, src len]

        return output, attention
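
# During training the model is called with the target shifted right,
# model(src, trg[:, :-1]) as in the demo below, so the decoder sees tokens
# up to position t while the output at t is compared against trg[:, 1:].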


    # if __name__ == "__main__":
    # set_trace()
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(
    # device
    # )
    # trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

    # src_pad_idx = 0
    # trg_pad_idx = 0
    # src_vocab_size = 10
    # trg_vocab_size = 10
    # model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx).to(
    # device
    # )
    # out, attention = model(x, trg[:, :-1])
    # print(out.shape)


    # INPUT_DIM = len(SRC.vocab)
    # OUTPUT_DIM = len(TRG.vocab)

    if __name__ == "__main__":
    pass

    if __name__ == "__main__":
    # set_trace()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    INPUT_DIM = 10
    OUTPUT_DIM = 10
    HID_DIM = 256
    ENC_LAYERS = 3
    DEC_LAYERS = 3
    ENC_HEADS = 8
    DEC_HEADS = 8
    ENC_PF_DIM = 512
    DEC_PF_DIM = 512
    ENC_DROPOUT = 0.1
    DEC_DROPOUT = 0.1

    SRC_PAD_IDX = 0
    TRG_PAD_IDX = 0

    enc = Encoder(INPUT_DIM,
    HID_DIM,
    ENC_LAYERS,
    ENC_HEADS,
    ENC_PF_DIM,
    ENC_DROPOUT,
    device)

    dec = Decoder(OUTPUT_DIM,
    HID_DIM,
    DEC_LAYERS,
    DEC_HEADS,
    DEC_PF_DIM,
    DEC_DROPOUT,
    device)

    src_vocab_size = 10
    trg_vocab_size = 10

    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2, 0]]).to(
    device, dtype=torch.int64
    )
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0,0,0], [1, 5, 6, 2, 4, 7, 6, 2,0,0]]).to(device, dtype=torch.int64)


    model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(
    device
    )
    out, attention = model(x, trg[:, :-1])
    print(out.shape)
    print(attention.shape)
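
    # A minimal greedy-decoding sketch (assumptions for illustration: the
    # model is untrained, and token 1 plays the role of <sos> and token 2
    # of <eos>, matching the toy tensors above). It shows how the masks and
    # the encoder/decoder split are reused at inference time:
    model.eval()
    with torch.no_grad():
        src_mask = model.make_src_mask(x)
        enc_src = model.encoder(x, src_mask)  # encode the source once
        ys = torch.full((x.shape[0], 1), 1, dtype=torch.int64, device=device)  # start with <sos>
        for _ in range(9):  # generate up to 9 more tokens
            trg_mask = model.make_trg_mask(ys)
            output, _ = model.decoder(ys, enc_src, trg_mask, src_mask)
            next_tok = output[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick
            ys = torch.cat([ys, next_tok], dim=1)
        print(ys)  # random tokens here, since the model is untrained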