```python
## Weight norm is now added to pytorch as a pre-hook, so use that instead :)

import torch
import torch.nn as nn
from torch.nn import Parameter
from functools import wraps


class WeightNorm(nn.Module):
    append_g = '_g'
    append_v = '_v'

    def __init__(self, module, weights):
        super(WeightNorm, self).__init__()
        self.module = module
        self.weights = weights
        self._reset()

    def _reset(self):
        for name_w in self.weights:
            w = getattr(self.module, name_w)

            # construct g,v such that w = g/||v|| * v
            g = torch.norm(w)
            v = w / g.expand_as(w)
            g = Parameter(g.data)
            v = Parameter(v.data)
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v

            # remove w from parameter list
            del self.module._parameters[name_w]

            # add g and v as new parameters
            self.module.register_parameter(name_g, g)
            self.module.register_parameter(name_v, v)

    def _setweights(self):
        for name_w in self.weights:
            name_g = name_w + self.append_g
            name_v = name_w + self.append_v
            g = getattr(self.module, name_g)
            v = getattr(self.module, name_v)
            w = v * (g / torch.norm(v)).expand_as(v)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)


##############################################################
## An older version using a python decorator but might be buggy.
## Does not work when the module is replicated (e.g. nn.DataParallel)

def _decorate(forward, module, name, name_g, name_v):
    @wraps(forward)
    def decorated_forward(*args, **kwargs):
        g = module.__getattr__(name_g)
        v = module.__getattr__(name_v)
        w = v * (g / torch.norm(v)).expand_as(v)
        module.__setattr__(name, w)
        return forward(*args, **kwargs)
    return decorated_forward


def weight_norm(module, name):
    param = module.__getattr__(name)

    # construct g,v such that w = g/||v|| * v
    g = torch.norm(param)
    v = param / g.expand_as(param)
    g = Parameter(g.data)
    v = Parameter(v.data)
    name_g = name + '_g'
    name_v = name + '_v'

    # remove w from parameter list
    del module._parameters[name]

    # add g and v as new parameters
    module.register_parameter(name_g, g)
    module.register_parameter(name_v, v)

    # construct w every time before forward is called
    module.forward = _decorate(module.forward, module, name, name_g, name_v)
    return module
```
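As the header comment says, this is now built into PyTorch as a forward pre-hook. A minimal sketch of the built-in utility, `torch.nn.utils.weight_norm` (the sizes in the comments follow the official docs; by default it normalizes per output unit with `dim=0`, and passing `dim=None` should give the single-scalar norm used by the class above):

```python
import torch.nn as nn
from torch.nn.utils import weight_norm

# the pre-hook recomputes `weight` from `weight_g` and `weight_v` before every forward
m = weight_norm(nn.Linear(20, 40), name='weight')
print(m.weight_g.size())  # torch.Size([40, 1]) -- one norm per output unit (dim=0)
print(m.weight_v.size())  # torch.Size([40, 20])
```

The usage example that follows exercises the gist's `WeightNorm` class directly.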
```python
import torch
import torch.nn as nn

from pytorch_weight_norm import WeightNorm

x = torch.autograd.Variable(torch.randn(5, 10, 30, 30))
m = nn.ConvTranspose2d(10, 20, 3)
y = m(x)
print(m._parameters.keys())
# odict_keys(['weight', 'bias'])

m = WeightNorm(m, ['weight'])
y_wn = m(x)
print(m.module._parameters.keys())
# odict_keys(['bias', 'weight_g', 'weight_v'])

print(torch.norm(y - y_wn).data[0])
# 1.3324766769073904e-05 (not important to get this smaller)

## can also use within sequential
## and is also stackable
net = nn.Sequential(
    WeightNorm(nn.Linear(30, 10), ['weight']),
    nn.ReLU(),
    WeightNorm(nn.Linear(10, 20), ['weight', 'bias']),
)
```
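As a small sanity check (my addition, not part of the original gist), the decomposition can be verified directly on the wrapped module above: after a forward pass, `m.module.weight` should equal `g * v / ||v||` rebuilt from the new parameters.

```python
# reuse m and x from the snippet above; the forward pass refreshes m.module.weight
_ = m(x)
g = m.module.weight_g
v = m.module.weight_v
w_manual = v * (g / torch.norm(v)).expand_as(v)
print(torch.norm(w_manual - m.module.weight))  # expect a (near-)zero norm
```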
@rtqichen I could not find the data-dependent init. I thought it was important for weight norm. Isn't it?
How can the weight-normalization support added in PyTorch 0.2.0 be incorporated into new RNN projects?
http://pytorch.org/docs/master/nn.html#torch.nn.utils.weight_norm
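One way to try the built-in hook on an RNN is to apply it per parameter name; the sketch below is my own (not an official recipe) and assumes the standard `nn.LSTM` parameter names `weight_ih_l0` / `weight_hh_l0`:

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

# reparameterize the first layer's input-to-hidden and hidden-to-hidden weights
rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)
rnn = weight_norm(rnn, name='weight_ih_l0')
rnn = weight_norm(rnn, name='weight_hh_l0')

x = torch.autograd.Variable(torch.randn(7, 3, 10))  # (seq_len, batch, input_size)
out, (h, c) = rnn(x)

print(rnn._parameters.keys())
# weight_ih_l0 / weight_hh_l0 should now appear as ..._g and ..._v entries
```

On the GPU, cuDNN may warn that the RNN weights are no longer in a single contiguous chunk of memory; as far as I know that is a general side effect of reparameterizing RNN weights rather than something specific to this snippet.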
Hello, and thanks for sharing this elegant implementation. Where can I find the newer, updated version?
Thanks!
@rtqichen Thanks for contributing this code.
@all @greaber @Smerity @ypxie @hanzhanggit
Hi everyone reading this post. I have some questions about weight_norm; it would be great if you could help.
I tried to apply weight_norm to every convolution and linear layer (see the code here: https://github.com/xwuaustin/weight_norm/blob/master/cifar10_tutorial_weightNorm.py). However, the training loss on CIFAR-10 shows no difference from the original setting (see the picture below) over the first 10 epochs (6 iterations in the plot correspond to 1 epoch).

Now my questions:
1. Is there something wrong with the code I modified? I used the code from the cifar10_tutorial in PyTorch; all I did was add weightNorm at each layer:

```python
import torch.nn as nn
from torch.nn.utils import weight_norm as weightNorm


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        ### we use weight normalization after each convolution and linear transform
        self.conv1 = weightNorm(nn.Conv2d(3, 6, 5), name="weight")
        # print(self.conv1._parameters.keys())
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = weightNorm(nn.Conv2d(6, 16, 5), name="weight")
        self.fc1 = weightNorm(nn.Linear(16 * 5 * 5, 120), name="weight")
        self.fc2 = weightNorm(nn.Linear(120, 84), name="weight")
        self.fc3 = weightNorm(nn.Linear(84, 10), name="weight")
```
2. Are the updates of 'weight_g' and 'weight_v' (and the bias) computed using the formulation from the paper? (See the note after this comment.)
3. Can we do the initialization the way the paper suggests? (Also addressed in the note below.)
Thanks. Looking forward to your responses. :)
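Regarding question 2 above: in the weight-norm paper (Salimans & Kingma, 2016) the reparameterization is $w = \frac{g}{\lVert v\rVert} v$, and backpropagating through it with the chain rule (which is what autograd does) gives the paper's gradient expressions:

$$
\nabla_g L = \frac{\nabla_w L \cdot v}{\lVert v \rVert},
\qquad
\nabla_v L = \frac{g}{\lVert v \rVert}\,\nabla_w L \;-\; \frac{g\,\nabla_g L}{\lVert v \rVert^{2}}\, v .
$$

Regarding question 3: the data-dependent init from the paper is not performed by the gist or by the built-in hook. A rough sketch of one way to do it for a single weight-normed linear layer (the helper name and the 1e-8 epsilon are my own choices; it assumes the built-in hook's default `dim=0`, i.e. one `g` per output unit):

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm


def data_dependent_init(layer, x, init_scale=1.0):
    # Run one batch through the layer, then rescale g and shift the bias so the
    # pre-activations have roughly zero mean and init_scale standard deviation.
    with torch.no_grad():
        t = layer(x)                              # pre-activations on the init batch
        mean, std = t.mean(0), t.std(0)
        scale = init_scale / (std + 1e-8)
        layer.weight_g.mul_(scale.view(-1, 1))    # one g per output row (dim=0)
        layer.bias.copy_((layer.bias - mean) * scale)
    return layer


layer = weight_norm(nn.Linear(30, 10), name='weight')
x_init = torch.randn(64, 30)
layer = data_dependent_init(layer, x_init)
```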
Excellent work solving the weight_norm(...) deep-copy problem! Thank you.


This breaks printing of modules for conv layers. A quick fix is to add
to
`_reset`.

EDIT: Thanks for sharing your code :)