import torch
from torch import nn

# 1.
# register_forward_pre_hook(hook)
# The hook will be called every time before forward() is invoked. It should have
# the following signature:
#     hook(module, input) -> None
# The hook should not modify the input.
def pre_hook_wrapper(name):
    def pre_hook(module, input):
        print(name, ':')
        print('input: ', input)
        print('\n')
    return pre_hook

# 2.
# register_forward_hook(hook)
# The hook will be called every time after forward() has computed an output. It
# should have the following signature:
#     hook(module, input, output) -> None
# The hook should not modify the input or output.
def forward_hook_wrapper(name):
    def forward_hook(module, input, output):
        print(name, ':')
        print('input: ', input)
        print('output: ', output)
        print('\n')
    return forward_hook

# 3.
# register_full_backward_hook(hook)
# (The older register_backward_hook is deprecated; register_full_backward_hook
# is its replacement and behaves correctly for modules with multiple inputs or
# outputs.)
# The hook will be called every time the gradients with respect to the module's
# inputs are computed. The hook should have the following signature:
#     hook(module, grad_input, grad_output) -> Tensor or None
# grad_input and grad_output may be tuples if the module has multiple inputs or
# outputs. The hook should not modify its arguments, but it can optionally
# return a new gradient with respect to the input that will be used in place of
# grad_input in subsequent computations.
def backward_hook_wrapper(name):
    def backward_hook(module, grad_input, grad_output):
        """
        grad_input: gradient of the loss w.r.t. the module's input
        grad_output: gradient of the loss w.r.t. the module's output
        """
        print(name, ':')
        print('grad_input: ', grad_input)
        print('grad_output: ', grad_output)
        print('\n')
    return backward_hook

# Example to clarify each hook's arguments
model = nn.ReLU()
handle_pre = model.register_forward_pre_hook(pre_hook_wrapper('forward_pre_hook'))
handle_forward = model.register_forward_hook(forward_hook_wrapper('forward_hook'))
handle_backward = model.register_full_backward_hook(backward_hook_wrapper('backward_hook'))

x = torch.tensor([-1.0, 1.0], requires_grad=True)
out = model(x)
print(out)
out = torch.sum(out)
print(out)
out.backward()

###### output #########
# forward_pre_hook :
# input:  (tensor([-1.,  1.], requires_grad=True),)
#
# forward_hook :
# input:  (tensor([-1.,  1.], requires_grad=True),)
# output:  tensor([0., 1.], grad_fn=<ReluBackward0>)
#
# tensor([0., 1.], grad_fn=<ReluBackward0>)
# tensor(1., grad_fn=<SumBackward0>)
# backward_hook :
# grad_input:  (tensor([0., 1.]),)
# grad_output:  (tensor([1., 1.]),)
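# The note in (3) that a backward hook may return a replacement for grad_input
# deserves a concrete illustration. Below is a minimal sketch (the
# clip_grad_input hook name is my own, not part of the PyTorch API) that clips
# the gradient flowing into a linear layer, then detaches the hook with
# handle.remove().

import torch
from torch import nn

# A full backward hook that returns a new grad_input tuple; the returned
# gradients are used in place of grad_input in subsequent computations.
def clip_grad_input(module, grad_input, grad_output):
    # Clamp every incoming gradient to [-0.5, 0.5]. None entries (inputs that
    # do not require grad) are passed through unchanged.
    return tuple(
        g.clamp(-0.5, 0.5) if g is not None else None for g in grad_input
    )

layer = nn.Linear(2, 2)
handle = layer.register_full_backward_hook(clip_grad_input)

x = torch.tensor([[3.0, -4.0]], requires_grad=True)
layer(x).sum().backward()
print(x.grad)  # every element now lies within [-0.5, 0.5]

# Hooks stay attached until explicitly removed; drop them when done so they do
# not fire on later forward/backward passes.
handle.remove()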