## Wide ResNet with Shift and incorrect hyperparams.
# Based on code by xternalz: https://github.com/xternalz/WideResNet-pytorch
# WRN by Sergey Zagoruyko and Nikos Komodakis
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as V
import torch.optim as optim
import numpy as np


# A "shift" layer: each group of channels is shifted by one pixel in a different
# direction, so the following 1x1 convolution can mix spatial information without
# a spatial kernel. We allocate any leftover channels to the center group.
class shift(nn.Module):
    def __init__(self, in_planes, kernel_size=3):
        super(shift, self).__init__()
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        # One channel group per position in the kernel_size x kernel_size window
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)

    # Leave the final group in place.
    # Note: the tops+bottoms vs. left+right naming is reversed (the first
    # spatial index is rows, the second is columns). Oh well.
    def forward(self, x):
        # Zero-pad by one pixel on each side so every shift is a simple crop
        x_pad = F.pad(x, (1, 1, 1, 1))
        # Alias for convenience
        cpg = self.channels_per_group
        cat_layers = []
        # Bottom shift, grab the Top element
        i = 0
        cat_layers += [x_pad[:, i * cpg: (i + 1) * cpg, :-2, 1:-1]]
        # Top shift, grab the Bottom element
        i = 1
        cat_layers += [x_pad[:, i * cpg: (i + 1) * cpg, 2:, 1:-1]]
        # Right shift, grab the Left element
        i = 2
        cat_layers += [x_pad[:, i * cpg: (i + 1) * cpg, 1:-1, :-2]]
        # Left shift, grab the Right element
        i = 3
        cat_layers += [x_pad[:, i * cpg: (i + 1) * cpg, 1:-1, 2:]]
        # Bottom Right shift, grab the Top Left element
        i = 4
        cat_layers += [x_pad[:, i * cpg: (i + 1) * cpg, :-2, :-2]]
        # Bottom Left shift, grab the Top Right element
        i = 5
        cat_layers += [x_pad[:, i * cpg: (i + 1) * cpg, :-2, 2:]]
        # Top Right shift, grab the Bottom Left element
        i = 6
        cat_layers += [x_pad[:, i * cpg: (i + 1) * cpg, 2:, :-2]]
        # Top Left shift, grab the Bottom Right element
        i = 7
        cat_layers += [x_pad[:, i * cpg: (i + 1) * cpg, 2:, 2:]]
        # Center group (including any leftover channels): no shift
        i = 8
        cat_layers += [x_pad[:, i * cpg:, 1:-1, 1:-1]]
        return torch.cat(cat_layers, 1)
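
# --- Illustrative sketch (not from the original gist) ---
# A minimal shape check for the shift module above, assuming a 32-channel
# input: 32 // 9 = 3 channels go to each of the eight shifted groups, and the
# remaining 8 channels stay put in the center group. The output should have
# exactly the same shape as the input, since every group is just a different
# crop of the zero-padded activation map. The function name and sizes here
# are arbitrary, chosen only for illustration.
def _check_shift_shapes():
    m = shift(32)
    x = V(torch.randn(4, 32, 8, 8))
    y = m(x)
    # 8 groups * 3 channels + 8 leftover center channels = 32 channels out
    assert y.size() == x.size()
    return y
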
# Pre-activation residual block: BN-ReLU-1x1 conv, shift, BN-ReLU-1x1 conv,
# with a 1x1 convolutional shortcut when the channel count changes.
class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, stride, dropRate, E=9):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                               padding=0, bias=False)
        self.conv2 = shift(out_planes)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(out_planes, out_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes,
                               kernel_size=1, stride=stride,
                               padding=0, bias=False) or None

    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv2(self.conv1(out if self.equalInOut else x))))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv3(out)
        out = torch.add(x if self.equalInOut else self.convShortcut(x), out)
        # print(x.size(), out.size())
        return out


# Note: we call it DenseNet for simple compatibility with the training code.
# Similarly, we call it growthRate instead of widen_factor.
class Network(nn.Module):
    def __init__(self, widen_factor, depth, nClasses, epochs, dropRate=0.0):
        super(Network, self).__init__()
        self.epochs = epochs
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0
        n = int((depth - 4) / 6)
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = self._make_layer(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = self._make_layer(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = self._make_layer(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], nClasses)
        self.nChannels = nChannels[3]

        # He initialization for convs; unit scale / zero shift for BN; zero biases for linear
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

        # Optimizer
        self.lr = 1e-1
        self.optim = optim.SGD(params=self.parameters(), lr=self.lr,
                               nesterov=True, momentum=0.9,
                               weight_decay=1e-4)
        # Iteration counter
        self.j = 0

        # A simple dummy variable that indicates we are using an iteration-wise
        # annealing scheme as opposed to epoch-wise.
        self.lr_sched = {'itr': 0}

    def _make_layer(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        layers = []
        for i in range(nb_layers):
            layers.append(block(i == 0 and in_planes or out_planes, out_planes,
                                i == 0 and stride or 1, dropRate))
        return nn.Sequential(*layers)

    # Cosine-annealed learning rate, stepped once per iteration
    def update_lr(self, max_j):
        for param_group in self.optim.param_groups:
            param_group['lr'] = (0.5 * self.lr) * (1 + np.cos(np.pi * self.j / max_j))
        self.j += 1

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, (out.size(2), out.size(3)))
        out = out.view(-1, self.nChannels)
        return F.log_softmax(self.fc(out))
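
# --- Illustrative usage sketch (not from the original gist) ---
# A rough example of how an external training loop might drive this file:
# build a WRN-40-4-style network for CIFAR-10 (depth=40 satisfies the
# (depth - 4) % 6 == 0 assertion), run one forward/backward pass on a dummy
# 32x32 batch, and step the iteration-wise cosine learning-rate schedule.
# The batch size, depth, widen_factor, and total iteration count below are
# arbitrary assumptions for illustration, not values from the original gist.
def _dummy_training_step():
    net = Network(widen_factor=4, depth=40, nClasses=10, epochs=100)
    images = V(torch.randn(8, 3, 32, 32))                   # dummy CIFAR-sized batch
    labels = V(torch.LongTensor([3, 1, 4, 1, 5, 9, 2, 6]))  # dummy class labels in [0, 10)
    net.update_lr(max_j=100 * 391)   # assumed total iteration budget (epochs * iters/epoch)
    net.optim.zero_grad()
    log_probs = net(images)          # (8, 10) log-probabilities from log_softmax
    loss = F.nll_loss(log_probs, labels)
    loss.backward()
    net.optim.step()
    return loss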