from typing import List

import torch
from torch import nn


def filter_parameters_for_finetuning(module: nn.Module) -> List[torch.Tensor]:
    """
    Args:
        module: A :py:class:`nn.Module` whose submodules may carry a boolean
            attribute ``finetune``. If the attribute exists and is ``False``,
            that submodule's parameters (and those of all its descendants) are
            excluded from the result.

    Returns:
        A list of the parameters in ``module`` that should be finetuned. Note
        that only parameters owned by submodules are collected; parameters
        registered directly on ``module`` itself are not included.
    """
    params = []
    # Look at each direct child of the current module and recurse into those
    # that don't have ``finetune=False``.
    for child in module.children():
        if hasattr(child, 'finetune') and not child.finetune:
            # Don't recurse into this part of the module subtree: we want to
            # freeze all of its parameters.
            continue
        # If the child is going to be finetuned, add all parameters declared
        # directly on the child ...
        params.extend(child.parameters(recurse=False))
        # ... and, recursively, those of its descendants that aren't frozen.
        params.extend(filter_parameters_for_finetuning(child))
    return params


def demo():
    class SubSubSubModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(5, 7)

    class SubSubModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(10, 15)
            self.m = SubSubSubModule()
            self.finetune = False

    class SubModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(20, 30)
            self.m = SubSubModule()

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(40, 60)
            self.m = SubModule()

    net = Net()
    print([param.shape for param in filter_parameters_for_finetuning(net)])
    # Outputs:
    # [torch.Size([60, 40]),  <- Net.lin.weight
    #  torch.Size([60]),      <- Net.lin.bias
    #  torch.Size([30, 20]),  <- Net.m.lin.weight
    #  torch.Size([30])]      <- Net.m.lin.bias
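

# A minimal sketch of how the filtered list might be consumed, not part of the
# original example: the parameters returned by filter_parameters_for_finetuning
# can be passed straight to an optimizer, so frozen submodules are never
# updated. The ``Encoder``/``Model`` classes, the ``torch.optim.Adam`` choice,
# and the learning rate below are illustrative assumptions.
def demo_optimizer():
    class Encoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(8, 8)
            self.finetune = False  # freeze the whole encoder subtree

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.head = nn.Linear(8, 2)

        def forward(self, x):
            return self.head(self.encoder.lin(x))

    model = Model()
    trainable = filter_parameters_for_finetuning(model)
    # Only ``model.head``'s weight and bias end up in the optimizer.
    optimizer = torch.optim.Adam(trainable, lr=1e-3)

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()  # updates the head; the encoder's weights stay untouched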