# my_batchnorm.py, a gist by @ajbrock (created February 12, 2019)
import torch
import torch.nn as nn

# Manual BN
# Calculate means and variances using mean-of-squares minus mean-squared
def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
    # Calculate expected value of x (m) and expected value of x**2 (m2)
    # Mean of x over the batch and spatial dims, keeping the channel dim
    m = torch.mean(x, [0, 2, 3], keepdim=True)
    # Mean of x squared
    m2 = torch.mean(x ** 2, [0, 2, 3], keepdim=True)
    # Variance as mean of squares minus mean squared; in floating point this
    # can come out slightly negative, which the eps in fused_bn guards against.
    var = (m2 - m ** 2)

    if return_mean_var:
        return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
    else:
        return fused_bn(x, m, var, gain, bias, eps)

# Apply scale and shift; if gain and bias are provided, fuse them here
def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
    # Prepare scale
    scale = torch.rsqrt(var + eps)
    # If a gain is provided, fold it into the scale
    if gain is not None:
        scale = scale * gain
    # Prepare shift
    shift = mean * scale
    # If a bias is provided, fold it into the shift
    if bias is not None:
        shift = shift - bias
    # Equivalent to ((x - mean) / ((var + eps) ** 0.5)) * gain + bias
    return x * scale - shift
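
# --- Usage sketch (my addition, not part of the original gist) ---
# A minimal sanity check: with no gain or bias, manual_bn should match
# PyTorch's functional batch norm in training mode up to numerical precision.
# The names below (x, out_manual, out_ref) are illustrative, and this assumes
# the imports at the top of the file.
#
# import torch.nn.functional as F
#
# x = torch.randn(8, 16, 32, 32)
# out_manual = manual_bn(x)
# out_ref = F.batch_norm(x, running_mean=None, running_var=None,
#                        training=True, eps=1e-5)
# print(torch.allclose(out_manual, out_ref, atol=1e-4))  # expect True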


#### This is just a module wrapper that does the bookkeeping; the two functions above do the actual batch-norming

# My batchnorm, supports standing stats
class myBN(nn.Module):
    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super(myBN, self).__init__()
        # Momentum for updating running stats
        self.momentum = momentum
        # Epsilon to avoid dividing by 0
        self.eps = eps
        # Register buffers for the stored statistics
        self.register_buffer('stored_mean', torch.zeros(num_channels))
        self.register_buffer('stored_var', torch.ones(num_channels))
        self.register_buffer('accumulation_counter', torch.zeros(1))
        # If True, accumulate (rather than running-average) means and vars
        self.accumulate_standing = False

    # Reset standing stats
    def reset_stats(self):
        self.stored_mean[:] = 0
        self.stored_var[:] = 0
        self.accumulation_counter[:] = 0

    def forward(self, x, gain, bias):
        if self.training:
            out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
            # If accumulating standing stats, sum the batch stats into the buffers
            if self.accumulate_standing:
                self.stored_mean[:] = self.stored_mean + mean.data
                self.stored_var[:] = self.stored_var + var.data
                self.accumulation_counter += 1.0
            # If not accumulating standing stats, take running averages
            # (use .data so the buffers stay out of the autograd graph)
            else:
                self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean.data * self.momentum
                self.stored_var[:] = self.stored_var * (1 - self.momentum) + var.data * self.momentum
            return out
        # If not in training mode, use the stored stats and don't update them
        else:
            mean = self.stored_mean.view(1, -1, 1, 1)
            var = self.stored_var.view(1, -1, 1, 1)
            # If using standing stats, divide them by the accumulation counter
            if self.accumulate_standing:
                mean = mean / self.accumulation_counter
                var = var / self.accumulation_counter
            return fused_bn(x, mean, var, gain, bias, self.eps)
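
# --- Usage sketch (my addition, not part of the original gist) ---
# A hypothetical example of the two modes: ordinary running-average updates in
# training, then accumulating standing stats over several forward passes before
# evaluation. All names (bn, gain, bias, shapes) are illustrative; gain and
# bias are passed per-channel so they broadcast against NCHW activations.
#
# bn = myBN(16)
# gain = torch.ones(1, 16, 1, 1)
# bias = torch.zeros(1, 16, 1, 1)
#
# bn.train()
# y = bn(torch.randn(8, 16, 32, 32), gain, bias)  # updates running stats
#
# # Accumulate standing stats, then evaluate with their average
# bn.accumulate_standing = True
# bn.reset_stats()
# for _ in range(10):
#     bn(torch.randn(8, 16, 32, 32), gain, bias)
# bn.eval()
# y_eval = bn(torch.randn(8, 16, 32, 32), gain, bias)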