Gist d4a52ea75a7284ebafbccf87cf63414c by ajbrock, created February 12, 2019
import torch
import torch.nn as nn


# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
  # Calculate expected value of x (m) and expected value of x**2 (m2)
  # Mean of x over batch and spatial dims
  m = torch.mean(x, [0, 2, 3], keepdim=True)
  # Mean of x squared
  m2 = torch.mean(x ** 2, [0, 2, 3], keepdim=True)
  # Calculate variance as mean of squared minus mean squared: Var[x] = E[x**2] - E[x]**2
  var = (m2 - m ** 2)
  if return_mean_var:
    return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
  else:
    return fused_bn(x, m, var, gain, bias, eps)


# Apply scale and shift; if gain and bias are provided, fuse them here
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
  # Prepare scale
  scale = torch.rsqrt(var + eps)
  # If a gain is provided, use it
  if gain is not None:
    scale = scale * gain
  # Prepare shift
  shift = mean * scale
  # If bias is provided, use it
  if bias is not None:
    shift = shift - bias
  # Algebraically equivalent to ((x - mean) / sqrt(var + eps)) * gain + bias,
  # rearranged into a single multiply and subtract
  return x * scale - shift
  # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias


#### This is just a module wrapper that does bookkeeping; the above two functions do the batchnorming
# My batchnorm, supports standing stats
class myBN(nn.Module):
  def __init__(self, num_channels, eps=1e-5, momentum=0.1):
    super(myBN, self).__init__()
    # Momentum for updating running stats
    self.momentum = momentum
    # Epsilon to avoid dividing by 0
    self.eps = eps
    # Register buffers for stored statistics
    self.register_buffer('stored_mean', torch.zeros(num_channels))
    self.register_buffer('stored_var', torch.ones(num_channels))
    self.register_buffer('accumulation_counter', torch.zeros(1))
    # Flag: accumulate standing stats (raw sums) instead of running averages
    self.accumulate_standing = False

  # Reset standing stats
  def reset_stats(self):
    self.stored_mean[:] = 0
    self.stored_var[:] = 0
    self.accumulation_counter[:] = 0

  def forward(self, x, gain, bias):
    if self.training:
      out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
      # If accumulating standing stats, sum the minibatch stats and count batches
      if self.accumulate_standing:
        self.stored_mean[:] = self.stored_mean + mean.data
        self.stored_var[:] = self.stored_var + var.data
        self.accumulation_counter += 1.0
      # If not accumulating standing stats, take running averages
      # (use .data so the buffer updates stay out of the autograd graph)
      else:
        self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean.data * self.momentum
        self.stored_var[:] = self.stored_var * (1 - self.momentum) + var.data * self.momentum
      return out
    # If not in training mode, use the stored stats instead of batch stats
    else:
      mean = self.stored_mean.view(1, -1, 1, 1)
      var = self.stored_var.view(1, -1, 1, 1)
      # If using standing stats, divide them by the accumulation counter
      if self.accumulate_standing:
        mean = mean / self.accumulation_counter
        var = var / self.accumulation_counter
      return fused_bn(x, mean, var, gain, bias, self.eps)
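For context, here is a minimal sketch of how the module might be driven. The channel count, tensor shapes, batch count, and the loop itself are illustrative assumptions, not part of the gist: normal training keeps running averages in the buffers, while reset_stats() followed by accumulate_standing = True sums raw batch statistics over several forward passes, which eval() then averages by the counter before normalizing.

# Illustrative usage sketch (shapes and loop are assumptions, not from the gist)
bn = myBN(num_channels=64)

x = torch.randn(8, 64, 32, 32)     # (N, C, H, W) minibatch
gain = torch.ones(1, 64, 1, 1)     # per-channel gain, broadcastable over x
bias = torch.zeros(1, 64, 1, 1)    # per-channel bias, broadcastable over x

# Normal training: buffers hold running averages of mean/var
bn.train()
out = bn(x, gain, bias)

# Standing stats: accumulate raw sums of batch stats over several batches
bn.reset_stats()
bn.accumulate_standing = True
for _ in range(16):
  bn(torch.randn(8, 64, 32, 32), gain, bias)

# Eval divides the accumulated sums by the counter before normalizing
bn.eval()
out = bn(x, gain, bias)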