@Wizaron
Last active August 6, 2018 21:06
ReNet with Simple Recurrent Unit (https://arxiv.org/pdf/1505.00393.pdf, https://arxiv.org/pdf/1709.02755.pdf) in PyTorch 0.4.0
import torch
from torch import nn
from torch.nn import functional as F
class ReNetSRU(nn.Module):
    r"""ReNet with Simple Recurrent Unit.

    The Simple Recurrent Unit (SRU) is defined in 'Training RNNs as Fast as CNNs'
    (https://arxiv.org/pdf/1709.02755.pdf).
    ReNet is defined in 'ReNet: A Recurrent Neural Network Based Alternative
    to Convolutional Networks' (https://arxiv.org/pdf/1505.00393.pdf).

    This module implements a ReNet layer that scans a feature map with SRUs,
    first along the width and then along the height.

    Args:
        in_channels (int): Number of channels in the input feature map.

    Shape:
        - Input: `(N, C_{in}, H_{in}, W_{in})`
        - Output: `(N, C_{in}, H_{in}, W_{in})`

    Examples:
        >>> rsru = ReNetSRU(10)
        >>> input = torch.randn(3, 10, 64, 64)
        >>> output = rsru(input)

        >>> rsru = ReNetSRU(10).cuda()
        >>> input = torch.randn(3, 10, 64, 64).cuda()
        >>> output = rsru(input)

        >>> device = torch.device("cuda")
        >>> rsru = ReNetSRU(10).to(device)
        >>> input = torch.randn(3, 10, 64, 64).to(device)
        >>> output = rsru(input)
    """

    def __init__(self, in_channels):
        super(ReNetSRU, self).__init__()
        # 1x1 convolutions produce (x_hat, forget, reset) for the horizontal
        # and vertical sweeps, hence 3 * in_channels output channels.
        self.conv_hor = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.conv_ver = nn.Conv2d(in_channels, in_channels * 3, 1)

    def get_cell_states(self, forget, x_hat, device):
        bs, nf, h, w = x_hat.size()
        # input-dependent part of the cell state, computed for all steps at once
        c_indep = (1.0 - forget) * x_hat                    # bs, nf, h, w
        c_indep = c_indep.permute(0, 2, 3, 1)               # bs, h, w, nf
        c_indep = c_indep.contiguous().view(bs * h, w, nf)  # bs * h, w, nf
        forget = forget.permute(0, 2, 3, 1)                 # bs, h, w, nf
        forget = forget.contiguous().view(bs * h, w, nf)    # bs * h, w, nf
        # initial cell state (zeros); every row is an independent sequence
        c = torch.zeros(bs * h, nf, device=device)          # bs * h, nf
        c_out = []
        # sequential scan along the width: c_t = f_t * c_{t-1} + (1 - f_t) * x_hat_t
        for step in range(w):
            c = c * forget[:, step] + c_indep[:, step]
            c_out.append(c)
        c_out = torch.stack(c_out, dim=1)                   # bs * h, w, nf
        c_out = c_out.view(bs, h, w, nf)                    # bs, h, w, nf
        c_out = c_out.permute(0, 3, 1, 2)                   # bs, nf, h, w
        return c_out

    def sru_forward(self, x, conv):
        nf = x.size(1)
        device = x.device
        # a single 1x1 convolution produces x_hat and both gates; split along channels
        x_hat, forget, reset = conv(x).split(nf, dim=1)
        forget = torch.sigmoid(forget)
        reset = torch.sigmoid(reset)
        c_out = self.get_cell_states(forget, x_hat, device)
        # highway connection: blend the activated cell state with the input
        h_out = reset * F.relu(c_out) + (1.0 - reset) * x
        return h_out

    def forward(self, x):
        # horizontal sweep (along the width), then transpose and sweep vertically
        x = self.sru_forward(x, self.conv_hor)  # bs, nf, h, w
        x = x.permute(0, 1, 3, 2)               # bs, nf, w, h
        x = self.sru_forward(x, self.conv_ver)  # bs, nf, w, h
        x = x.permute(0, 1, 3, 2)               # bs, nf, h, w
        return x

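# A minimal sketch of how ReNetSRU might be combined with a patch projection:
# the ReNet paper first splits the input into non-overlapping patches before
# the RNN sweeps, and a strided convolution can play that role here. The
# `patch_size` argument and layer sizes below are illustrative assumptions.
class ReNetSRUBlock(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size=2):
        super(ReNetSRUBlock, self).__init__()
        # project each patch_size x patch_size patch to out_channels features
        self.patch = nn.Conv2d(in_channels, out_channels,
                               kernel_size=patch_size, stride=patch_size)
        self.renet = ReNetSRU(out_channels)

    def forward(self, x):
        return self.renet(self.patch(x))
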
if __name__ == '__main__':
    in_channels = 256
    device = torch.device('cpu')
    x = torch.rand(1, in_channels, 120, 256, device=device)
    rsru = ReNetSRU(in_channels).to(device)
    out = rsru(x)
    print(x.size(), out.size())
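    # A minimal sketch of a backward pass through the module, checking that it
    # is differentiable end to end; the mean-based loss below is an arbitrary
    # illustrative choice.
    loss = out.mean()
    loss.backward()
    print(rsru.conv_hor.weight.grad.norm().item())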