@pvva
Created November 3, 2019 17:42
rnn2
import torch
import torch.nn as nn
import torch.nn.functional as F


class CharRnn(nn.Module):
    def __init__(self, vocab_size, n_fac, n_hidden, batch_size, layers=2):
        super().__init__()
        self.e = nn.Embedding(vocab_size, n_fac)
        self.rnn = nn.LSTM(n_fac, n_hidden, layers, dropout=0.1)
        self.l_out = nn.Linear(n_hidden, vocab_size)
        self.n_hidden = n_hidden
        self.layers = layers
        self.init_hidden_state(batch_size)

    def init_hidden_state(self, batch_size):
        # LSTM state is a (hidden, cell) tuple, each of shape
        # (layers, batch_size, n_hidden).
        self.h = (
            torch.zeros(self.layers, batch_size, self.n_hidden).cuda(),
            torch.zeros(self.layers, batch_size, self.n_hidden).cuda(),
        )

    def forward(self, inp):
        # inp: (seq_len, batch_size) of token indices.
        inp = self.e(inp)
        b_size = inp[0].size(0)
        # Re-initialise the state if the batch size changed
        # (e.g. the last, smaller batch of an epoch).
        if self.h[0].size(1) != b_size:
            self.init_hidden_state(b_size)
        outp, h = self.rnn(inp, self.h)
        # Detach the state so BPTT is truncated at batch boundaries.
        self.h = detach_from_history(h)
        return F.log_softmax(self.l_out(outp), dim=-1)
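
The forward pass calls a detach_from_history helper that the gist does not define. Below is a minimal sketch of what it presumably does, assuming it follows the usual truncated-BPTT pattern of detaching the recurrent state from the computation graph (the name and exact behaviour are assumptions, not part of the gist):

def detach_from_history(h):
    # Assumed helper (not included in the gist): detach a tensor, or
    # recursively each tensor in the (hidden, cell) tuple, so gradients do
    # not flow back into previous batches.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(detach_from_history(v) for v in h)

Detaching rather than re-zeroing keeps the state values across batches while freeing the graph built by the previous batch, so memory stays bounded during training. As a quick smoke test, the model can then be driven with random token indices (all sizes below are illustrative, and the .cuda() calls in the class assume a GPU is available):

model = CharRnn(vocab_size=85, n_fac=42, n_hidden=256, batch_size=64).cuda()
x = torch.randint(0, 85, (8, 64)).cuda()  # (seq_len, batch_size)
log_probs = model(x)                      # (seq_len, batch_size, vocab_size)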