@zhpmatrix
Created March 8, 2019 13:32
rnn_multi_gpu
    """
    https://discuss.pytorch.org/t/multi-layer-rnn-with-dataparallel/4450/2

    https://pytorch.org/docs/stable/nn.html

    """
import os
# Limit this process to GPUs 0 and 1; set before any CUDA context is created.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import torch


class Net(torch.nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Net, self).__init__()
        # Two-layer GRU; batch_first=False, so tensors are (seq_len, batch, feature).
        self.gru = torch.nn.GRU(input_size, hidden_size, num_layers=2, batch_first=False)
        for p in self.gru.parameters():
            torch.nn.init.normal_(p)

    def forward(self, input_, h0):
        output, ht = self.gru(input_, h0)
        return output, ht


if __name__ == '__main__':
    # dim=1 makes DataParallel scatter and gather along the batch dimension,
    # which is dim 1 when batch_first=False; the hidden state h0 is
    # (num_layers, batch, hidden_size), so it is also split along dim 1.
    model = torch.nn.DataParallel(Net(10, 200), device_ids=[0, 1], dim=1).cuda()
    input_ = torch.randn(5, 3, 10)   # (seq_len=5, batch=3, input_size=10)
    h0 = torch.randn(2, 3, 200)      # (num_layers=2, batch=3, hidden_size=200)
    output, hn = model(input_, h0)
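    # Sanity check: DataParallel scatters input_ and h0 into chunks along
    # dim 1 (the batch dimension), runs each chunk on its own replica, and
    # gathers the results back on GPU 0, so per the torch.nn.GRU docs the
    # gathered shapes should match a single-GPU run:
    print(output.shape)  # torch.Size([5, 3, 200]) = (seq_len, batch, hidden_size)
    print(hn.shape)      # torch.Size([2, 3, 200]) = (num_layers, batch, hidden_size)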