import torch
from torch import nn

# 1.
# register_forward_pre_hook(hook)
# The hook will be called every time before forward() is invoked. It should have the following signature:
# `hook(module, input) -> None`
# The hook should not modify the input.
def pre_hook_wrapper(name):
    def pre_hook(module, input):
        print(name, ':')
        print('input: ', input)
        print('\n')
    return pre_hook

# 2.
# register_forward_hook(hook)
# The hook will be called every time after forward() has computed an output. It should have the following signature:
# `hook(module, input, output) -> None`
# The hook should not modify the input or output.
def forward_hook_wrapper(name):
    def forward_hook(module, input, output):
        print(name, ':')
        print('input: ', input)
        print('output: ', output)
        print('\n')
    return forward_hook

# 3.
# register_backward_hook(hook)
# The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature:
# `hook(module, grad_input, grad_output) -> Tensor or None`
# The grad_input and grad_output may be tuples if the module has multiple inputs or outputs.
# The hook should not modify its arguments, but it can optionally return a new gradient with respect to input
# that will be used in place of grad_input in subsequent computations.
def backward_hook_wrapper(name):
    def backward_hook(module, grad_input, grad_output):
        """
        grad_input: grad of module input
        grad_output: grad of module output
        """
        print(name, ':')
        print('grad_input: ', grad_input)
        print('grad_output: ', grad_output)
        print('\n')
    return backward_hook

# Example to clarify each hook's arguments
model = torch.nn.ReLU()
handle_pre = model.register_forward_pre_hook(pre_hook_wrapper('forward_pre_hook'))
handle_forward = model.register_forward_hook(forward_hook_wrapper('forward_hook'))
handle_backward = model.register_backward_hook(backward_hook_wrapper('backward_hook'))

x = torch.tensor([-1.0, 1.0], requires_grad=True)
out = model(x)
print(out)
out = torch.sum(out)
print(out)
out.backward()

###### output #########
# forward_pre_hook :
# input: (tensor([-1., 1.], requires_grad=True),)

# forward_hook :
# input: (tensor([-1., 1.], requires_grad=True),)
# output: tensor([0., 1.], grad_fn=<ThresholdBackward0>)

# tensor([0., 1.], grad_fn=<ThresholdBackward0>)
# tensor(1., grad_fn=<SumBackward0>)

# backward_hook :
# grad_input: (tensor([0., 1.]),)
# grad_output: (tensor([1., 1.]),)
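
# ---------------------------------------------------------------------------
# A minimal sketch (not part of the original gist) of the "return a new
# gradient" behaviour described in section 3 above: a backward hook may return
# a tuple that is used in place of grad_input in subsequent computations.
# The 0.5 scale factor, the second ReLU, and the names below are illustrative
# assumptions, not part of the original example.

def scaling_backward_hook(module, grad_input, grad_output):
    # Return a tuple matching grad_input's structure; these values replace
    # the original grad_input as the gradient flows further back.
    return tuple(g * 0.5 if g is not None else None for g in grad_input)

model2 = torch.nn.ReLU()
handle_scale = model2.register_backward_hook(scaling_backward_hook)

y = torch.tensor([-1.0, 1.0], requires_grad=True)
torch.sum(model2(y)).backward()
print(y.grad)  # tensor([0.0000, 0.5000]) -- ReLU's grad_input [0., 1.] halved by the hook

# Hooks stay registered until their handle is removed; each register_* call
# above returned a handle for exactly this purpose:
handle_scale.remove()
handle_pre.remove()
handle_forward.remove()
handle_backward.remove()

# Note: in PyTorch >= 1.8, register_full_backward_hook is the recommended
# replacement for register_backward_hook, which is deprecated there because it
# can report incorrect grad_input for modules with multiple inputs/outputs.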