Skip to content

Instantly share code, notes, and snippets.

@ajbrock
Last active March 14, 2018 23:19
Show Gist options
  • Save ajbrock/26ae0dfe80167667af010ae3708ca276 to your computer and use it in GitHub Desktop.
Save ajbrock/26ae0dfe80167667af010ae3708ca276 to your computer and use it in GitHub Desktop.

Revisions

  1. ajbrock revised this gist Mar 14, 2018. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion count_flops.py
    Original file line number Diff line number Diff line change
    @@ -8,7 +8,7 @@ def count_conv_flops(self, input, output):
    flops_c = self.out_channels * self.in_channels / self.groups
    # Flop contribution from number of spatial locations we convolve over
    flops_s = output.size(2) * output.size(3)
    # Flop contribution from number of mult-adds at
    # Flop contribution from number of mult-adds at each location
    flops_f = self.kernel_size[0] * self.kernel_size[1]
    data_dict['conv_flops'] += flops_c * flops_s * flops_f
    return
  2. ajbrock created this gist Mar 14, 2018.
    32 changes: 32 additions & 0 deletions count_flops.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,32 @@
    import torch

    # Dict to store hooks and flop count
    data_dict = {'conv_flops' : 0, 'hooks' :[]}

    def count_conv_flops(self, input, output):
    # Flop contribution from channelwise connections
    flops_c = self.out_channels * self.in_channels / self.groups
    # Flop contribution from number of spatial locations we convolve over
    flops_s = output.size(2) * output.size(3)
    # Flop contribution from number of mult-adds at
    flops_f = self.kernel_size[0] * self.kernel_size[1]
    data_dict['conv_flops'] += flops_c * flops_s * flops_f
    return

    def add_hooks(m):
    if isinstance(m, torch.nn.Conv2d):
    data_dict['hooks'] += [m.register_forward_hook(count_conv_flops)]
    return

    def count_flops(model, x):
    data_dict['conv_flops'] = 0
    # Note if we need to return the model to training mode
    set_train = model.training
    model.eval()
    model.apply(add_hooks)
    out = model(torch.autograd.Variable(x.data, volatile=True))
    for hook in data_dict['hooks']:
    hook.remove()
    if set_train:
    model.train()
    return data_dict['conv_flops']