Last active
May 2, 2025 10:47
-
-
Save raytroop/79435c6b5b8df722ae97cf5cdc1c64be to your computer and use it in GitHub Desktop.
Trying out PyTorch module hooks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch import nn | |
# 1.
# register_forward_pre_hook(hook)
# The hook is called every time before forward() is invoked, with the signature:
#     `hook(module, args) -> None or modified input`
# `args` holds only the positional arguments passed to the module. Returning a
# value replaces the input; the hook built here only prints and returns None,
# so the input reaches forward() unchanged.
def pre_hook_wrapper(name):
    """Return a forward pre-hook that prints the module's inputs under *name*."""
    def _report_inputs(_module, inputs):
        # Each row is one print() call: header, the inputs tuple, then a
        # separator that prints two blank lines.
        for row in ((name, ':'), ('input: ', inputs), ('\n',)):
            print(*row)
    return _report_inputs
# 2.
# register_forward_hook(hook)
# The hook is called every time after forward() has computed an output, with
# the signature:
#     `hook(module, args, output) -> None or modified output`
# Returning a value replaces the module's output; the hook built here only
# prints and returns None, so the output is left untouched.
def forward_hook_wrapper(name):
    """Return a forward hook that prints the module's inputs and output under *name*."""
    def _report_io(_module, inputs, result):
        # One print() per row: header, inputs, output, then a separator that
        # prints two blank lines.
        rows = (
            (name, ':'),
            ('input: ', inputs),
            ('output: ', result),
            ('\n',),
        )
        for row in rows:
            print(*row)
    return _report_io
# 3.
# register_full_backward_hook(hook)  -- replaces the deprecated register_backward_hook(hook)
# The hook is called every time the gradients with respect to the module are
# computed, with the signature:
#     `hook(module, grad_input, grad_output) -> tuple(Tensor) or None`
# grad_input / grad_output are tuples of gradients with respect to the module's
# (positional) inputs and its outputs. The hook should not modify its arguments,
# but it may return a new gradient with respect to the input, which is used in
# place of grad_input in subsequent computations.
# With the deprecated register_backward_hook, grad_input may instead describe the
# inputs of the last autograd op inside forward(), not the module's own inputs.
def backward_hook_wrapper(name):
    """Return a backward hook that prints the gradients it receives under *name*.

    The returned hook only prints and returns None, so gradients are not altered.
    """
    def _report_grads(_module, grad_input, grad_output):
        """
        grad_input: gradients w.r.t. the inputs (see note above for the deprecated API)
        grad_output: gradients w.r.t. the module output(s)
        """
        rows = (
            (name, ':'),
            ('grad_input: ', grad_input),
            ('grad_output: ', grad_output),
            ('\n',),
        )
        for row in rows:
            print(*row)
    return _report_grads
# Example to clarify the hooks' arguments.
model = nn.ReLU()
handle_pre = model.register_forward_pre_hook(pre_hook_wrapper('forward_pre_hook'))
handle_forward = model.register_forward_hook(forward_hook_wrapper('forward_hook'))
# register_backward_hook is deprecated: when forward() runs several autograd ops
# it reports gradients of the last op rather than of the module's inputs.
# register_full_backward_hook reports gradients w.r.t. the module's inputs/outputs.
handle_backward = model.register_full_backward_hook(backward_hook_wrapper('backward_hook'))

x = torch.tensor([-1.0, 1.0], requires_grad=True)
out = model(x)
print(out)
out = torch.sum(out)
print(out)
out.backward()

# Hooks stay attached to `model` until their handles are removed.
handle_pre.remove()
handle_forward.remove()
handle_backward.remove()

###### output #########
# NOTE: recorded with an older PyTorch and the deprecated register_backward_hook.
# grad_fn labels depend on the PyTorch version, and with register_full_backward_hook
# the module's inputs/output are wrapped, so the forward hook's input and the
# returned output may show grad_fn=<BackwardHookFunctionBackward>. Each hook also
# prints two blank lines (omitted here). The gradient values are the same.
# forward_pre_hook :
# input: (tensor([-1., 1.], requires_grad=True),)
# forward_hook :
# input: (tensor([-1., 1.], requires_grad=True),)
# output: tensor([0., 1.], grad_fn=<ThresholdBackward0>)
# tensor([0., 1.], grad_fn=<ThresholdBackward0>)
# tensor(1., grad_fn=<SumBackward0>)
# backward_hook :
# grad_input: (tensor([0., 1.]),)
# grad_output: (tensor([1., 1.]),)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment