Created
May 5, 2015 07:31
-
-
Save karpathy/7bae8033dcf5ca2630ba to your computer and use it in GitHub Desktop.
Revisions
-
karpathy revised this gist
May 5, 2015 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,5 +1,5 @@ --[[ Efficient LSTM in Torch using nngraph library. This code was optimized by Justin Johnson (@jcjohnson) based on the trick of batching up the LSTM GEMMs, as also seen in my efficient Python LSTM gist. --]] -
karpathy created this gist
May 5, 2015 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,32 @@ --[[ Efficient LSTM in Torch using nngraph library this code was optimized by Justin Johnson (@jcjohnson) based on the trick of batching up the LSTM GEMMs, as also seen in my efficient Python LSTM gist. --]] function LSTM.fast_lstm(input_size, rnn_size) local x = nn.Identity()() local prev_c = nn.Identity()() local prev_h = nn.Identity()() local i2h = nn.Linear(input_size, 4 * rnn_size)(x) local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h) local all_input_sums = nn.CAddTable()({i2h, h2h}) local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums) sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk) local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk) local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk) local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk) local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums) in_transform = nn.Tanh()(in_transform) local next_c = nn.CAddTable()({ nn.CMulTable()({forget_gate, prev_c}), nn.CMulTable()({in_gate, in_transform}) }) local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)}) return nn.gModule({x, prev_c, prev_h}, {next_c, next_h}) end