Revisions

  1. @wassname revised this gist Nov 1, 2016. 1 changed file, 1 addition and 1 deletion: keras_attention_wrapper.py (filled in the docstring url with https://gist.github.com/wassname/5292f95000e409e239b9dc973295327a).
  2. @wassname created this gist Nov 1, 2016: keras_attention_wrapper.py, 228 additions.
    """
    A Keras attention layer that wraps RNN layers.
    Based on TensorFlow's [attention_decoder](https://github.com/tensorflow/tensorflow/blob/c8a45a8e236776bed1d14fd71f3b6755bd63cc58/tensorflow/python/ops/seq2seq.py#L506)
    and [Grammar as a Foreign Language](https://arxiv.org/abs/1412.7449).
    date: 20161101
    author: wassname
    url: https://gist.github.com/wassname/5292f95000e409e239b9dc973295327a
    """

    from keras import backend as K
    from keras.engine import InputSpec
    from keras.layers import LSTM, activations, Wrapper, Recurrent

    class Attention(Wrapper):
        """
        This wrapper adds an attention mechanism on top of a recurrent layer.

        # Arguments:
            layer: `Recurrent` instance with consume_less='gpu' or 'mem'

        # Examples:

        ```python
        model = Sequential()
        model.add(LSTM(10, return_sequences=True, batch_input_shape=(4, 5, 10)))
        model.add(Attention(LSTM(10, return_sequences=True, consume_less='gpu')))
        model.add(Dense(5))
        model.add(Activation('softmax'))
        model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
        ```

        # References
        - [Grammar as a Foreign Language](https://arxiv.org/abs/1412.7449)
        """
        def __init__(self, layer, **kwargs):
            assert isinstance(layer, Recurrent)
            if layer.get_config()['consume_less'] == 'cpu':
                raise Exception("Attention does not support RNNs with consume_less='cpu'")
            self.supports_masking = True
            super(Attention, self).__init__(layer, **kwargs)

        def build(self, input_shape):
            assert len(input_shape) >= 3
            self.input_spec = [InputSpec(shape=input_shape)]
            nb_samples, nb_time, input_dim = input_shape

            if not self.layer.built:
                self.layer.build(input_shape)
                self.layer.built = True

            super(Attention, self).build()

            self.W1 = self.layer.init((input_dim, input_dim, 1, 1), name='{}_W1'.format(self.name))
            self.W2 = self.layer.init((self.layer.output_dim, input_dim), name='{}_W2'.format(self.name))
            self.b2 = K.zeros((input_dim,), name='{}_b2'.format(self.name))
            self.W3 = self.layer.init((input_dim * 2, input_dim), name='{}_W3'.format(self.name))
            self.b3 = K.zeros((input_dim,), name='{}_b3'.format(self.name))
            self.V = self.layer.init((input_dim,), name='{}_V'.format(self.name))

            self.trainable_weights = [self.W1, self.W2, self.W3, self.V, self.b2, self.b3]

        def get_output_shape_for(self, input_shape):
            return self.layer.get_output_shape_for(input_shape)

        def step(self, x, states):
            # Based on [tensorflow's attention_decoder](https://github.com/tensorflow/tensorflow/blob/c8a45a8e236776bed1d14fd71f3b6755bd63cc58/tensorflow/python/ops/seq2seq.py#L506).
            # First, compute the attention weights over the input sequence X:
            #     attn = softmax(V^T * tanh(W1 * X + W2 * h + b2))
            # then build the new input as a concatenation of the current input and the
            # attention-weighted inputs, transformed back to the shape of x using W3:
            #     x = W3 * concat([x, X * attn]) + b3
            # Finally, run the wrapped cell on this combined input:
            #     h, state = cell(x, h)
            # (See the NumPy sketch after this class for the same arithmetic spelled out.)

            nb_samples, nb_time, input_dim = self.input_spec[0].shape
            h = states[0]
            X = states[-1]
            xW1 = states[-2]

            Xr = K.reshape(X, (-1, nb_time, 1, input_dim))
            hW2 = K.dot(h, self.W2) + self.b2
            hW2 = K.reshape(hW2, (-1, 1, 1, input_dim))
            u = K.tanh(xW1 + hW2)
            a = K.sum(self.V * u, [2, 3])
            a = K.softmax(a)
            a = K.reshape(a, (-1, nb_time, 1, 1))

            # Weight the inputs by the attention weights
            Xa = K.sum(a * Xr, [1, 2])
            Xa = K.reshape(Xa, (-1, input_dim))

            # Merge the input and the attention-weighted inputs into one vector of the right size.
            x = K.dot(K.concatenate([x, Xa], 1), self.W3) + self.b3

            h, new_states = self.layer.step(x, states)
            return h, new_states

        def get_constants(self, x):
            constants = self.layer.get_constants(x)

            # Calculate K.dot(x, W1) only once per sequence by making it a constant
            nb_samples, nb_time, input_dim = self.input_spec[0].shape
            Xr = K.reshape(x, (-1, nb_time, input_dim, 1))
            Xrt = K.permute_dimensions(Xr, (0, 2, 1, 3))
            xW1t = K.conv2d(Xrt, self.W1, border_mode='same')
            xW1 = K.permute_dimensions(xW1t, (0, 2, 3, 1))
            constants.append(xW1)

            # We need to supply the full sequence of inputs to step() (as the attention vector)
            constants.append(x)

            return constants

        def call(self, x, mask=None):
            # input shape: (nb_samples, time (padded with zeros), input_dim)
            input_shape = self.input_spec[0].shape
            if K._BACKEND == 'tensorflow':
                if not input_shape[1]:
                    raise Exception('When using TensorFlow, you should define '
                                    'explicitly the number of timesteps of '
                                    'your sequences.\n'
                                    'If your first layer is an Embedding, '
                                    'make sure to pass it an "input_length" '
                                    'argument. Otherwise, make sure '
                                    'the first layer has '
                                    'an "input_shape" or "batch_input_shape" '
                                    'argument, including the time axis. '
                                    'Found input shape at layer ' + self.name +
                                    ': ' + str(input_shape))

            if self.layer.stateful:
                initial_states = self.layer.states
            else:
                initial_states = self.layer.get_initial_states(x)
            constants = self.get_constants(x)
            preprocessed_input = self.layer.preprocess_input(x)

            last_output, outputs, states = K.rnn(self.step, preprocessed_input,
                                                 initial_states,
                                                 go_backwards=self.layer.go_backwards,
                                                 mask=mask,
                                                 constants=constants,
                                                 unroll=self.layer.unroll,
                                                 input_length=input_shape[1])
            if self.layer.stateful:
                self.updates = []
                for i in range(len(states)):
                    self.updates.append((self.layer.states[i], states[i]))

            if self.layer.return_sequences:
                return outputs
            else:
                return last_output
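
    # --------------------------------------------------------------------------
    # Illustrative sketch (not part of the original gist): a minimal NumPy
    # walk-through of the attention arithmetic used in Attention.step() above,
    # for a single sample and a single decoder step. The shapes and variable
    # names are assumptions chosen for readability, and W1 is applied as a plain
    # matrix product instead of the conv2d trick used in get_constants().
    def _attention_step_sketch():
        import numpy as np

        nb_time, input_dim, output_dim = 5, 3, 4
        rng = np.random.RandomState(0)

        X = rng.randn(nb_time, input_dim)         # full input sequence
        h = rng.randn(output_dim)                 # current RNN state
        x_t = rng.randn(input_dim)                # current input
        W1 = rng.randn(input_dim, input_dim)      # projects the inputs X
        W2 = rng.randn(output_dim, input_dim)     # projects the state h
        b2 = np.zeros(input_dim)
        W3 = rng.randn(2 * input_dim, input_dim)  # merges [x_t, Xa] back to input_dim
        b3 = np.zeros(input_dim)
        V = rng.randn(input_dim)

        # attn = softmax(V^T * tanh(W1 * X + W2 * h + b2)), one weight per timestep
        u = np.tanh(X.dot(W1) + h.dot(W2) + b2)   # (nb_time, input_dim)
        scores = u.dot(V)                         # (nb_time,)
        attn = np.exp(scores - scores.max())
        attn /= attn.sum()                        # sums to 1 over the nb_time positions

        # attention-weighted sum of the inputs, then merge with the current input
        Xa = (attn[:, None] * X).sum(axis=0)              # (input_dim,)
        x_new = np.concatenate([x_t, Xa]).dot(W3) + b3    # (input_dim,)
        return attn, x_new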


    # tests like those in https://github.com/fchollet/keras/blob/master/tests/keras/layers/test_wrappers.py
    import pytest
    import numpy as np
    from numpy.testing import assert_allclose
    from keras.utils.test_utils import keras_test
    from keras.layers import wrappers, Input, InputLayer, core, recurrent
    from keras.models import Sequential, Model, model_from_json

    nb_samples, timesteps, embedding_dim, output_dim = 2, 5, 3, 4
    embedding_num = 12
    x = np.random.random((nb_samples, timesteps, embedding_dim))
    y = np.random.random((nb_samples, timesteps, output_dim))

    # baseline test with LSTM
    model = Sequential()
    model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))
    model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, return_sequences=True, consume_less='mem')))
    model.add(core.Activation('relu'))
    model.compile(optimizer='rmsprop', loss='mse')
    model.fit(x, y, nb_epoch=1, batch_size=nb_samples)
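
    # Optional sanity check (an assumption beyond the original tests): with
    # return_sequences=True the wrapper should keep the wrapped RNN's output
    # shape, i.e. (nb_samples, timesteps, output_dim).
    preds = model.predict(x, batch_size=nb_samples)
    assert preds.shape == (nb_samples, timesteps, output_dim)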


    # test stacked with all RNN layers and consume_less options
    model = Sequential()
    model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))

    # test supported consume_less options
    # model.add(Attention(recurrent.LSTM(embedding_dim, input_dim=embedding_dim, consume_less='cpu', return_sequences=True)))  # not supported
    model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, consume_less='gpu', return_sequences=True)))
    model.add(Attention(recurrent.LSTM(embedding_dim, input_dim=embedding_dim, consume_less='mem', return_sequences=True)))
    # test each of the other RNN types
    model.add(Attention(recurrent.GRU(embedding_dim, input_dim=embedding_dim, consume_less='mem', return_sequences=True)))
    model.add(Attention(recurrent.SimpleRNN(embedding_dim, input_dim=embedding_dim, consume_less='mem', return_sequences=True)))
    model.add(core.Activation('relu'))
    model.compile(optimizer='rmsprop', loss='mse')
    model.fit(x, y, nb_epoch=1, batch_size=nb_samples)

    # test with return_sequences=False
    model = Sequential()
    model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))
    model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, return_sequences=False, consume_less='mem')))
    model.add(core.Activation('relu'))
    model.compile(optimizer='rmsprop', loss='mse')
    model.fit(x, y[:, -1, :], nb_epoch=1, batch_size=nb_samples)
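
    # And with return_sequences=False the output should collapse to the last
    # step, i.e. (nb_samples, output_dim) (again an added check, not in the gist).
    preds = model.predict(x, batch_size=nb_samples)
    assert preds.shape == (nb_samples, output_dim)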

    # test with a bidirectional encoder
    model = Sequential()
    model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))
    model.add(wrappers.Bidirectional(recurrent.LSTM(embedding_dim, input_dim=embedding_dim, return_sequences=True)))
    model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, return_sequences=True, consume_less='mem')))
    model.add(core.Activation('relu'))
    model.compile(optimizer='rmsprop', loss='mse')
    model.fit(x, y, nb_epoch=1, batch_size=nb_samples)

    # test config
    model.get_config()

    # test to and from json
    model = model_from_json(model.to_json(), custom_objects=dict(Attention=Attention))
    model.summary()

    # test with functional API
    input = Input(batch_shape=(nb_samples, timesteps, embedding_dim))
    output = Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, return_sequences=True, consume_less='mem'))(input)
    model = Model(input, output)
    model.compile(optimizer='rmsprop', loss='mse')
    model.fit(x, y, nb_epoch=1, batch_size=nb_samples)