Last active: March 14, 2018 23:19
        
      - 
      
 - 
        
Save ajbrock/26ae0dfe80167667af010ae3708ca276 to your computer and use it in GitHub Desktop.  
Revisions
ajbrock revised this gist on Mar 14, 2018. 1 changed file with 1 addition and 1 deletion.
```diff
@@ -8,7 +8,7 @@ def count_conv_flops(self, input, output):
     flops_c = self.out_channels * self.in_channels / self.groups
     # Flop contribution from number of spatial locations we convolve over
     flops_s = output.size(2) * output.size(3)
-    # Flop contribution from number of mult-adds at
+    # Flop contribution from number of mult-adds at each location
     flops_f = self.kernel_size[0] * self.kernel_size[1]
     data_dict['conv_flops'] += flops_c * flops_s * flops_f
     return
```
        
ajbrock created this gist on Mar 14, 2018.

```python
import torch

# Dict to store hooks and flop count
data_dict = {'conv_flops': 0, 'hooks': []}

def count_conv_flops(self, input, output):
    # Flop contribution from channelwise connections
    flops_c = self.out_channels * self.in_channels / self.groups
    # Flop contribution from number of spatial locations we convolve over
    flops_s = output.size(2) * output.size(3)
    # Flop contribution from number of mult-adds at
    flops_f = self.kernel_size[0] * self.kernel_size[1]
    data_dict['conv_flops'] += flops_c * flops_s * flops_f
    return

def add_hooks(m):
    if isinstance(m, torch.nn.Conv2d):
        data_dict['hooks'] += [m.register_forward_hook(count_conv_flops)]
    return

def count_flops(model, x):
    data_dict['conv_flops'] = 0
    # Note if we need to return the model to training mode
    set_train = model.training
    model.eval()
    model.apply(add_hooks)
    out = model(torch.autograd.Variable(x.data, volatile=True))
    for hook in data_dict['hooks']:
        hook.remove()
    if set_train:
        model.train()
    return data_dict['conv_flops']
```
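To make the per-layer count concrete: the hook charges out_channels · in_channels / groups mult-adds for each kernel tap at each output location. A minimal worked sketch, using a hypothetical 3×3 convolution (64 input channels, 128 output channels, groups=1, 56×56 output map) that is not taken from the gist:

```python
# Hypothetical layer shape, chosen only to illustrate the formula in count_conv_flops.
out_channels, in_channels, groups = 128, 64, 1
out_h, out_w = 56, 56
k_h, k_w = 3, 3

flops_c = out_channels * in_channels / groups  # channelwise connections
flops_s = out_h * out_w                        # spatial locations convolved over
flops_f = k_h * k_w                            # mult-adds at each location

print(int(flops_c * flops_s * flops_f))        # 231211008 mult-adds for this one layer
```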
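A possible way to drive `count_flops`, sketched against a torchvision model; the model choice and input shape here are assumptions for illustration, not part of the gist. Note that `Variable(..., volatile=True)` is the pre-0.4 PyTorch inference API; on current releases one would typically replace that line with a forward pass under `torch.no_grad()`.

```python
import torch
import torchvision

# Illustrative model and input batch; any nn.Module containing Conv2d layers works.
model = torchvision.models.resnet18()
x = torch.randn(1, 3, 224, 224)

# Counts only Conv2d mult-adds (linear layers, batch norm, etc. are ignored by the hooks).
print('Conv mult-adds:', count_flops(model, x))
```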
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,32 @@ import torch # Dict to store hooks and flop count data_dict = {'conv_flops' : 0, 'hooks' :[]} def count_conv_flops(self, input, output): # Flop contribution from channelwise connections flops_c = self.out_channels * self.in_channels / self.groups # Flop contribution from number of spatial locations we convolve over flops_s = output.size(2) * output.size(3) # Flop contribution from number of mult-adds at flops_f = self.kernel_size[0] * self.kernel_size[1] data_dict['conv_flops'] += flops_c * flops_s * flops_f return def add_hooks(m): if isinstance(m, torch.nn.Conv2d): data_dict['hooks'] += [m.register_forward_hook(count_conv_flops)] return def count_flops(model, x): data_dict['conv_flops'] = 0 # Note if we need to return the model to training mode set_train = model.training model.eval() model.apply(add_hooks) out = model(torch.autograd.Variable(x.data, volatile=True)) for hook in data_dict['hooks']: hook.remove() if set_train: model.train() return data_dict['conv_flops']