"""Run a two-layer GRU across two GPUs with ``torch.nn.DataParallel``.

``DataParallel`` splits its inputs along ``dim`` and sends one slice to each
device. With ``batch_first=False`` the GRU input is laid out as
``(seq_len, batch, input_size)`` and the initial hidden state as
``(num_layers, batch, hidden_size)``. The batch therefore sits on dim 1 for
both tensors, which is why the wrapper below is built with ``dim=1``.
The gathered outputs likewise carry the batch on dim 1.

References:
    https://discuss.pytorch.org/t/multi-layer-rnn-with-dataparallel/4450/2
    https://pytorch.org/docs/stable/nn.html
"""
import os

# CUDA reads this variable only when it is first initialised. Setting it before
# torch is imported guarantees it takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import torch  # noqa: E402  (must follow the environment setup above)


class Net(torch.nn.Module):
    """Two-layer GRU whose parameters are all drawn from N(0, 1)."""

    def __init__(self, input_size, hidden_size):
        """Build the GRU.

        Args:
            input_size: number of features per time step of the input.
            hidden_size: number of features in the hidden state / output.
        """
        super().__init__()
        self.gru = torch.nn.GRU(
            input_size, hidden_size, num_layers=2, batch_first=False
        )
        # Every weight and bias is re-initialised from a standard normal.
        for p in self.gru.parameters():
            torch.nn.init.normal_(p)

    def forward(self, input_, h0):
        """Run the GRU over a sequence.

        Args:
            input_: tensor of shape ``(seq_len, batch, input_size)``.
            h0: initial hidden state of shape ``(2, batch, hidden_size)``.

        Returns:
            ``(output, ht)``. ``output`` has shape
            ``(seq_len, batch, hidden_size)`` and ``ht`` has shape
            ``(2, batch, hidden_size)``.
        """
        # DataParallel replicas hold RNN weights that are no longer one
        # contiguous block. Re-flattening avoids the per-call cuDNN warning
        # and weight copy. On CPU this is a no-op.
        self.gru.flatten_parameters()
        output, ht = self.gru(input_, h0)
        return output, ht


if __name__ == '__main__':
    if torch.cuda.device_count() < 2:
        raise SystemExit(
            'This example needs at least two visible CUDA devices.'
        )
    # dim=1 scatters/gathers along the batch axis of both input_ and h0.
    model = torch.nn.DataParallel(
        Net(10, 200), device_ids=[0, 1], dim=1
    ).cuda()
    input_ = torch.randn(5, 3, 10)   # (seq_len=5, batch=3, input_size=10)
    h0 = torch.randn(2, 3, 200)      # (num_layers=2, batch=3, hidden=200)
    output, hn = model(input_, h0)
    print(tuple(output.shape), tuple(hn.shape))