@Wizaron
Last active August 6, 2018 21:06

Revisions

  1. Wizaron revised this gist Aug 6, 2018. No changes.
  2. Wizaron revised this gist Aug 6, 2018. 1 changed file with 1 addition and 0 deletions.
  3. Wizaron created this gist Aug 6, 2018.

renet_sru.py (104 lines):

import torch
from torch import nn
from torch.nn import functional as F


class ReNetSRU(nn.Module):

    r"""ReNet with Simple Recurrent Units.

    The Simple Recurrent Unit (SRU) is defined in 'Training RNNs as Fast as
    CNNs' (https://arxiv.org/pdf/1709.02755.pdf).
    ReNet is defined in 'ReNet: A Recurrent Neural Network Based Alternative
    to Convolutional Networks' (https://arxiv.org/pdf/1505.00393.pdf).
    This module implements a ReNet layer that scans a feature map using SRUs.

    Args:
        in_channels (int): Number of channels in the input feature map.

    Shape:
        - Input: `(N, C_{in}, H_{in}, W_{in})`
        - Output: `(N, C_{in}, H_{in}, W_{in})`

    Examples:
        >>> rsru = ReNetSRU(10)
        >>> input = torch.randn(3, 10, 64, 64)
        >>> output = rsru(input)

        >>> rsru = ReNetSRU(10).cuda()
        >>> input = torch.randn(3, 10, 64, 64).cuda()
        >>> output = rsru(input)

        >>> device = torch.device("cuda")
        >>> rsru = ReNetSRU(10).to(device)
        >>> input = torch.randn(3, 10, 64, 64).to(device)
        >>> output = rsru(input)
    """
    def __init__(self, in_channels):
        super(ReNetSRU, self).__init__()

        # 1x1 convolutions compute the three SRU pre-activations
        # (transformed input, forget gate, reset gate) for the
        # horizontal and vertical sweeps, respectively.
        self.conv_hor = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.conv_ver = nn.Conv2d(in_channels, in_channels * 3, 1)

    def get_cell_states(self, forget, x_hat, device):
        # SRU cell recurrence along the last (width) dimension:
        #   c_t = f_t * c_{t-1} + (1 - f_t) * x_hat_t
        bs, nf, h, w = x_hat.size()

        c_indep = (1.0 - forget) * x_hat                     # bs, nf, h, w
        c_indep = c_indep.permute(0, 2, 3, 1)                # bs, h, w, nf
        c_indep = c_indep.contiguous().view(bs * h, w, nf)   # bs * h, w, nf

        forget = forget.permute(0, 2, 3, 1)                  # bs, h, w, nf
        forget = forget.contiguous().view(bs * h, w, nf)     # bs * h, w, nf

        # Zero initial cell state; it requires no gradient.
        c = torch.zeros(bs * h, nf, device=device)           # bs * h, nf

        c_out = []
        for step in range(w):
            c = c * forget[:, step] + c_indep[:, step]
            c_out.append(c)

        c_out = torch.stack(c_out, dim=1)    # bs * h, w, nf
        c_out = c_out.view(bs, h, w, nf)     # bs, h, w, nf
        c_out = c_out.permute(0, 3, 1, 2)    # bs, nf, h, w

        return c_out

    def sru_forward(self, x, conv):
        nf = x.size(1)
        device = x.device

        # Split the 1x1 conv output into the transformed input and the
        # forget/reset gates, then apply the SRU output equation:
        #   h_t = r_t * relu(c_t) + (1 - r_t) * x_t
        x_hat, forget, reset = conv(x).split(nf, dim=1)
        forget = torch.sigmoid(forget)
        reset = torch.sigmoid(reset)

        c_out = self.get_cell_states(forget, x_hat, device)
        h_out = reset * F.relu(c_out) + (1.0 - reset) * x

        return h_out

    def forward(self, x):
        # Horizontal sweep along the width, then swap H and W so the
        # second sweep runs along the original height.
        x = self.sru_forward(x, self.conv_hor)   # bs, nf, h, w
        x = x.permute(0, 1, 3, 2)                # bs, nf, w, h
        x = self.sru_forward(x, self.conv_ver)   # bs, nf, w, h
        x = x.permute(0, 1, 3, 2)                # bs, nf, h, w

        return x


if __name__ == '__main__':
    in_channels = 256
    device = torch.device('cpu')

    # Quick shape check: the layer should preserve (N, C, H, W).
    x = torch.rand(1, in_channels, 120, 256).to(device)
    rsru = ReNetSRU(in_channels).to(device)
    out = rsru(x)
    print(x.size(), out.size())
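
A usage sketch beyond the shape check above: in the spirit of the ReNet paper, the layer can be stacked with ordinary downsampling convolutions to build a small encoder. The snippet below is not part of the gist; the stage widths, strides, and the class name TinyReNetEncoder are illustrative assumptions.

import torch
from torch import nn

from renet_sru import ReNetSRU  # the gist file above, assumed to be on the path


class TinyReNetEncoder(nn.Module):
    """Hypothetical stack of (strided conv -> ReNetSRU) stages."""

    def __init__(self, in_channels=3, widths=(32, 64, 128)):
        super(TinyReNetEncoder, self).__init__()
        stages = []
        prev = in_channels
        for width in widths:
            stages += [
                nn.Conv2d(prev, width, kernel_size=3, stride=2, padding=1),
                nn.ReLU(inplace=True),
                ReNetSRU(width),  # horizontal + vertical SRU sweeps at this scale
            ]
            prev = width
        self.stages = nn.Sequential(*stages)

    def forward(self, x):
        return self.stages(x)


enc = TinyReNetEncoder()
feats = enc(torch.randn(2, 3, 64, 64))
print(feats.size())  # torch.Size([2, 128, 8, 8])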