@karpathy
Created May 5, 2015 07:31

Revisions

  1. karpathy revised this gist May 5, 2015. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion gistfile1.lua
    @@ -1,5 +1,5 @@
     --[[
    -Efficient LSTM in Torch using nngraph library this code was optimized
    +Efficient LSTM in Torch using nngraph library. This code was optimized
     by Justin Johnson (@jcjohnson) based on the trick of batching up the
     LSTM GEMMs, as also seen in my efficient Python LSTM gist.
     --]]
  2. karpathy created this gist May 5, 2015.
    32 changes: 32 additions & 0 deletions gistfile1.lua
    @@ -0,0 +1,32 @@
    +--[[
    +Efficient LSTM in Torch using nngraph library this code was optimized
    +by Justin Johnson (@jcjohnson) based on the trick of batching up the
    +LSTM GEMMs, as also seen in my efficient Python LSTM gist.
    +--]]
    +
    +function LSTM.fast_lstm(input_size, rnn_size)
    +  local x = nn.Identity()()
    +  local prev_c = nn.Identity()()
    +  local prev_h = nn.Identity()()
    +
    +  local i2h = nn.Linear(input_size, 4 * rnn_size)(x)
    +  local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
    +  local all_input_sums = nn.CAddTable()({i2h, h2h})
    +
    +  local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
    +  sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
    +  local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
    +  local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
    +  local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)
    +
    +  local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
    +  in_transform = nn.Tanh()(in_transform)
    +
    +  local next_c = nn.CAddTable()({
    +    nn.CMulTable()({forget_gate, prev_c}),
    +    nn.CMulTable()({in_gate, in_transform})
    +  })
    +  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
    +
    +  return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
    +end
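
The header comment credits the speed-up to batching the LSTM GEMMs: rather than one matrix multiply per gate, `fast_lstm` runs a single `nn.Linear` of width `4 * rnn_size` and slices the result with `nn.Narrow`. The sketch below is not part of the gist; it is a minimal check, assuming a working Torch7 install with the `nn` package, that one fused layer gives the same numbers as four separate layers that reuse its weights. The helper name `max_gate_mismatch` and the sizes passed to it are hypothetical, chosen only for illustration.

```lua
require 'torch'
require 'nn'

-- Hypothetical helper: largest absolute difference between one fused Linear
-- layer of width 4 * rnn_size and four separate Linear layers that copy its
-- weights gate by gate.
local function max_gate_mismatch(input_size, rnn_size, batch_size)
  local fused = nn.Linear(input_size, 4 * rnn_size)
  local x = torch.randn(batch_size, input_size)
  local fused_out = fused:forward(x)  -- batch_size x (4 * rnn_size)

  local worst = 0
  for gate = 1, 4 do
    -- Same 1-based offsets as the nn.Narrow calls in fast_lstm.
    local first = (gate - 1) * rnn_size + 1

    -- nn.Linear stores its weight as (output x input), so each gate is a
    -- contiguous block of rows in the fused weight matrix.
    local separate = nn.Linear(input_size, rnn_size)
    separate.weight:copy(fused.weight:narrow(1, first, rnn_size))
    separate.bias:copy(fused.bias:narrow(1, first, rnn_size))

    -- The matching block of columns in the fused output.
    local chunk = fused_out:narrow(2, first, rnn_size)
    local diff = (separate:forward(x) - chunk):abs():max()
    worst = math.max(worst, diff)
  end
  return worst
end

-- Expected to be zero up to floating-point rounding.
print(max_gate_mismatch(3, 2, 5))
```

As posted, the function is defined as a field of a table named `LSTM` and uses `nn` and `nngraph` without loading them, so running it likely needs something like `require 'nn'`, `require 'nngraph'` and `local LSTM = {}` in scope first. The returned `gModule` takes its inputs in the order `{x, prev_c, prev_h}` and returns `{next_c, next_h}`.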