## Weight norm is now added to pytorch as a pre-hook, so use that instead :) import torch import torch.nn as nn from torch.nn import Parameter from functools import wraps class WeightNorm(nn.Module): append_g = '_g' append_v = '_v' def __init__(self, module, weights): super(WeightNorm, self).__init__() self.module = module self.weights = weights self._reset() def _reset(self): for name_w in self.weights: w = getattr(self.module, name_w) # construct g,v such that w = g/||v|| * v g = torch.norm(w) v = w/g.expand_as(w) g = Parameter(g.data) v = Parameter(v.data) name_g = name_w + self.append_g name_v = name_w + self.append_v # remove w from parameter list del self.module._parameters[name_w] # add g and v as new parameters self.module.register_parameter(name_g, g) self.module.register_parameter(name_v, v) def _setweights(self): for name_w in self.weights: name_g = name_w + self.append_g name_v = name_w + self.append_v g = getattr(self.module, name_g) v = getattr(self.module, name_v) w = v*(g/torch.norm(v)).expand_as(v) setattr(self.module, name_w, w) def forward(self, *args): self._setweights() return self.module.forward(*args) ############################################################## ## An older version using a python decorator but might be buggy. ## Does not work when the module is replicated (e.g. nn.DataParallel) def _decorate(forward, module, name, name_g, name_v): @wraps(forward) def decorated_forward(*args, **kwargs): g = module.__getattr__(name_g) v = module.__getattr__(name_v) w = v*(g/torch.norm(v)).expand_as(v) module.__setattr__(name, w) return forward(*args, **kwargs) return decorated_forward def weight_norm(module, name): param = module.__getattr__(name) # construct g,v such that w = g/||v|| * v g = torch.norm(param) v = param/g.expand_as(param) g = Parameter(g.data) v = Parameter(v.data) name_g = name + '_g' name_v = name + '_v' # remove w from parameter list del module._parameters[name] # add g and v as new parameters module.register_parameter(name_g, g) module.register_parameter(name_v, v) # construct w every time before forward is called module.forward = _decorate(module.forward, module, name, name_g, name_v) return module