@pvva
Created November 3, 2019 17:41
rnn1_2
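
This snippet is a minimal character-level RNN in PyTorch: an embedding layer feeds an `nn.RNN`, a linear layer maps the final hidden output back to vocabulary logits, and the hidden state is carried across batches but detached from the autograd graph each step (truncated backpropagation through time).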
import torch
import torch.nn as nn
import torch.nn.functional as F


def detach_from_history(h):
    # Detach the hidden state from the autograd graph so that
    # backpropagation is truncated at the batch boundary. Handles
    # both a plain tensor (RNN/GRU) and a tuple of tensors (LSTM).
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(detach_from_history(v) for v in h)


class CharRnn(nn.Module):
    def __init__(self, vocab_size, n_fac, n_hidden, batch_size):
        super().__init__()
        self.e = nn.Embedding(vocab_size, n_fac)
        self.rnn = nn.RNN(n_fac, n_hidden)
        self.l_out = nn.Linear(n_hidden, vocab_size)
        self.n_hidden = n_hidden
        self.init_hidden_state(batch_size)

    def init_hidden_state(self, batch_size):
        # Hidden state shape: (num_layers, batch_size, n_hidden).
        self.h = torch.zeros(1, batch_size, self.n_hidden).cuda()

    def forward(self, inp):
        # inp: (seq_len, batch_size) of character indices;
        # after embedding: (seq_len, batch_size, n_fac).
        inp = self.e(inp)
        b_size = inp.size(1)
        # Re-initialize the hidden state if the batch size changed
        # (e.g. for the last, smaller batch of an epoch).
        if self.h.size(1) != b_size:
            self.init_hidden_state(b_size)
        outp, h = self.rnn(inp, self.h)
        # Keep the hidden state values but cut the graph history.
        self.h = detach_from_history(h)
        # Predict the next character from the last time step only.
        return F.log_softmax(self.l_out(outp[-1]), dim=-1)
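
For context, here is a minimal sketch of one training step with this model. The sizes, learning rate, and random stand-in data are illustrative assumptions, not part of the original gist; a real run would draw `x` and `y` from a character corpus.

# Hypothetical sizes; random data stands in for a real corpus.
# Assumes a CUDA device, matching .cuda() in init_hidden_state.
vocab_size, n_fac, n_hidden, batch_size, seq_len = 85, 42, 256, 64, 8

model = CharRnn(vocab_size, n_fac, n_hidden, batch_size).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Inputs are (seq_len, batch_size) index tensors; the target is
# the next character for each sequence in the batch.
x = torch.randint(0, vocab_size, (seq_len, batch_size)).cuda()
y = torch.randint(0, vocab_size, (batch_size,)).cuda()

log_probs = model(x)             # (batch_size, vocab_size)
loss = F.nll_loss(log_probs, y)  # pairs with the log_softmax output
loss.backward()
optimizer.step()
optimizer.zero_grad()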