Skip to content

Instantly share code, notes, and snippets.

@raytroop
Last active May 2, 2025 10:47
Show Gist options
  • Save raytroop/79435c6b5b8df722ae97cf5cdc1c64be to your computer and use it in GitHub Desktop.
Save raytroop/79435c6b5b8df722ae97cf5cdc1c64be to your computer and use it in GitHub Desktop.
try hook of pytorch
import torch
from torch import nn
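# 1.
# register_forward_pre_hook(hook)
# The hook will be called every time before forward() is invoked. It should have the following signature:
# `hook(module, args) -> None or modified input`
# `args` holds only the positional arguments given to the module (keyword arguments are not passed to the hook).
# Returning None leaves the input unchanged; the hook may instead return a modified input (a tuple, or a single value that will be wrapped into a tuple).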
def pre_hook_wrapper(name):
    """Build a forward pre-hook that prints the inputs a module is about to receive.

    Args:
        name: Label printed before the inputs, so output from several hooks
            can be told apart.

    Returns:
        A callable ``pre_hook(module, args)`` for
        ``nn.Module.register_forward_pre_hook``. It returns ``None``, so the
        module's inputs are passed to ``forward()`` unchanged.
    """
    def pre_hook(module, args):
        # `args` is the positional input tuple handed to the hook; it is named
        # `args` (not `input`) so the builtin `input` is not shadowed.
        print(name, ':')
        print('input: ', args)
        print('\n')  # blank separator between successive hook reports
    return pre_hook
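# 2.
# register_forward_hook(hook)
# The hook will be called every time after forward() has computed an output. It should have the following signature:
# `hook(module, args, output) -> None or modified output`
# Returning None keeps the output; the hook may instead return a modified output. Modifying the input here has no effect on forward(), which has already run.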
def forward_hook_wrapper(name):
    """Build a forward hook that prints a module's inputs and its computed output.

    Args:
        name: Label printed before the values, so output from several hooks
            can be told apart.

    Returns:
        A callable ``forward_hook(module, args, output)`` for
        ``nn.Module.register_forward_hook``. It returns ``None``, so the
        module's output is left unchanged.
    """
    def forward_hook(module, args, output):
        # `args` is the positional input tuple; named `args` (not `input`)
        # so the builtin `input` is not shadowed.
        print(name, ':')
        print('input: ', args)
        print('output: ', output)
        print('\n')  # blank separator between successive hook reports
    return forward_hook
# 3.
# register_backward_hook(hook)
# NOTE: this method is deprecated in favor of register_full_backward_hook(hook), which uses the same hook signature.
# The legacy version can report gradients for only a subset of the inputs/outputs when forward() performs several operations.
# The hook will be called every time the gradients with respect to module inputs are computed. It should have the following signature:
# `hook(module, grad_input, grad_output) -> Tensor or None`
# grad_input and grad_output may be tuples if the module has multiple inputs or outputs.
# The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input, which will be used in place of grad_input
# in subsequent computations.
def backward_hook_wrapper(name):
    """Return a backward hook that prints the gradients it is given, tagged with *name*.

    The returned callable has the ``hook(module, grad_input, grad_output)``
    signature expected by module backward hooks. It always returns ``None``,
    so the gradients flowing through autograd are not replaced.
    """
    def backward_hook(module, grads_wrt_inputs, grads_wrt_outputs):
        """Print the gradients w.r.t. the module's inputs and outputs.

        grads_wrt_inputs: grad of module input
        grads_wrt_outputs: grad of module output
        """
        labelled = (
            ('grad_input: ', grads_wrt_inputs),
            ('grad_output: ', grads_wrt_outputs),
        )
        print(name, ':')
        for label, grads in labelled:
            print(label, grads)
        print('\n')  # blank separator between successive hook reports
    return backward_hook
# Example to clarify the hooks' arguments
model = torch.nn.ReLU()
handle_pre = model.register_forward_pre_hook(pre_hook_wrapper('forward_pre_hook'))
handle_forward = model.register_forward_hook(forward_hook_wrapper('forward_hook'))
# register_backward_hook is deprecated (it can report gradients for only a
# subset of inputs/outputs when forward() runs several ops);
# register_full_backward_hook takes the same hook(module, grad_input, grad_output).
handle_backward = model.register_full_backward_hook(backward_hook_wrapper('backward_hook'))
x = torch.tensor([-1.0, 1.0], requires_grad=True)
out = model(x)
print(out)
out = torch.sum(out)
print(out)
out.backward()
###### output #########
# NOTE: recorded with an older PyTorch using the legacy register_backward_hook.
# With register_full_backward_hook the gradient values are the same, but the
# printed reprs may differ: the forward_hook input and the printed `out` can
# show grad_fn=<BackwardHookFunctionBackward>, and newer PyTorch versions may
# name the ReLU node ReluBackward0 instead of ThresholdBackward0.
# forward_pre_hook :
# input: (tensor([-1., 1.], requires_grad=True),)
# forward_hook :
# input: (tensor([-1., 1.], requires_grad=True),)
# output: tensor([0., 1.], grad_fn=<ThresholdBackward0>)
# tensor([0., 1.], grad_fn=<ThresholdBackward0>)
# tensor(1., grad_fn=<SumBackward0>)
# backward_hook :
# grad_input: (tensor([0., 1.]),)
# grad_output: (tensor([1., 1.]),)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment