from collections import OrderedDict from typing import Callable, Dict, Optional from warnings import warn import torch def _remove_all_forward_hooks( module: torch.nn.Module, hook_fn_name: Optional[str] = None ) -> None: """ This function removes all forward hooks in the specified module, without requiring any hook handles. This lets us clean up & remove any hooks that weren't property deleted. Warning: Various PyTorch modules and systems make use of hooks, and thus extreme caution should be exercised when removing all hooks. Users are recommended to give their hook function a unique name that can be used to safely identify and remove the target forward hooks. Args: module (nn.Module): The module instance to remove forward hooks from. hook_fn_name (str, optional): Optionally only remove specific forward hooks based on their function's __name__ attribute. Default: None """ if hook_fn_name is None: warn("Removing all active hooks will break some PyTorch modules & systems.") def _remove_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None: if hasattr(module, "_forward_hooks"): if m._forward_hooks != OrderedDict(): if name is not None: dict_items = list(m._forward_hooks.items()) m._forward_hooks = OrderedDict( [(i, fn) for i, fn in dict_items if fn.__name__ != name] ) else: m._forward_hooks: Dict[int, Callable] = OrderedDict() def _remove_child_hooks( target_module: torch.nn.Module, hook_name: Optional[str] = None ) -> None: for name, child in target_module._modules.items(): if child is not None: _remove_hooks(child, hook_name) _remove_child_hooks(child, hook_name) # Remove hooks from target submodules _remove_child_hooks(module, hook_fn_name) # Remove hooks from the target module _remove_hooks(module, hook_fn_name) from collections import OrderedDict from typing import List, Optional import torch def _count_forward_hooks( module: torch.nn.Module, hook_fn_name: Optional[str] = None ) -> int: """ Count the number of active forward hooks on the specified module instance. Args: module (nn.Module): The model module instance to count the number of forward hooks on. name (str, optional): Optionally only count specific forward hooks based on their function's __name__ attribute. Default: None Returns: num_hooks (int): The number of active hooks in the specified module. """ num_hooks: List[int] = [0] def _count_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None: if hasattr(m, "_forward_hooks"): if m._forward_hooks != OrderedDict(): dict_items = list(m._forward_hooks.items()) for i, fn in dict_items: if hook_fn_name is None or fn.__name__ == name: num_hooks[0] += 1 def _count_child_hooks( target_module: torch.nn.Module, hook_name: Optional[str] = None, ) -> None: for name, child in target_module._modules.items(): if child is not None: _count_hooks(child, hook_name) _count_child_hooks(child, hook_name) _count_child_hooks(module, hook_fn_name) _count_hooks(module, hook_fn_name) return num_hooks[0]