@raytroop
Last active May 2, 2025 10:47

Revisions

  1. raytroop revised this gist Sep 26, 2018. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion hook_pth.py
    @@ -65,7 +65,7 @@ def backward_hook(module, grad_input, grad_output):
    print(out)
    out.backward()


    ###### output #########
    # forward_pre_hook :
    # input: (tensor([-1., 1.], requires_grad=True),)

  2. raytroop revised this gist Sep 26, 2018. 1 changed file with 2 additions and 0 deletions.
    2 changes: 2 additions & 0 deletions hook_pth.py
    @@ -68,9 +68,11 @@ def backward_hook(module, grad_input, grad_output):

    # forward_pre_hook :
    # input: (tensor([-1., 1.], requires_grad=True),)

    # forward_hook :
    # input: (tensor([-1., 1.], requires_grad=True),)
    # output: tensor([0., 1.], grad_fn=<ThresholdBackward0>)

    # tensor([0., 1.], grad_fn=<ThresholdBackward0>)
    # tensor(1., grad_fn=<SumBackward0>)
    # backward_hook :
  3. raytroop created this gist Sep 26, 2018.
    78 changes: 78 additions & 0 deletions hook_pth.py
    @@ -0,0 +1,78 @@
    import torch
    from torch import nn

    # 1.
    # register_forward_pre_hook(hook)
    # The hook will be called every time before forward() is invoked. It should have the following signature:
    # `hook(module, input) -> None`

    # The hook should not modify the input.

    def pre_hook_wrapper(name):
        def pre_hook(module, input):
            print(name, ':')
            print('input: ', input)
            print('\n')
        return pre_hook
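
    # (Added sketch, not part of the original gist) A pre-hook only sees the
    # module and its input; e.g. one that just records input shapes.
    # `shape_log_wrapper` is an illustrative name.
    def shape_log_wrapper(name):
        def shape_log(module, input):
            # `input` is always passed as a tuple, even for a single tensor
            print(name, 'input shapes:', [tuple(t.shape) for t in input])
        return shape_log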

    # 2.
    # register_forward_hook(hook)
    # The hook will be called every time after forward() has computed an output. It should have the following signature:
    # `hook(module, input, output) -> None`

    # The hook should not modify the input or output.

    def forward_hook_wrapper(name):
        def forward_hook(module, input, output):
            print(name, ':')
            print('input: ', input)
            print('output: ', output)
            print('\n')
        return forward_hook
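
    # (Added sketch, not part of the original gist) A common use of forward hooks
    # is to capture intermediate activations without modifying the output.
    # `activations` and `save_activation_wrapper` are illustrative names.
    activations = {}

    def save_activation_wrapper(name):
        def save_activation(module, input, output):
            # detach so the stored tensor does not keep the autograd graph alive
            activations[name] = output.detach()
        return save_activation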

    # 3.
    # register_backward_hook(hook)
    # The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature:
    # `hook(module, grad_input, grad_output) -> Tensor or None`

    # The grad_input and grad_output may be tuples if the module has multiple inputs or outputs.
    # The hook should not modify its arguments, but it can optionally return a new gradient with respect to input that will be used in place of grad_input
    # in subsequent computations.

    def backward_hook_wrapper(name):
        def backward_hook(module, grad_input, grad_output):
            """
            grad_input: gradient of the loss w.r.t. the module's input
            grad_output: gradient of the loss w.r.t. the module's output
            """
            print(name, ':')
            print('grad_input: ', grad_input)
            print('grad_output: ', grad_output)
            print('\n')
        return backward_hook
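
    # (Added sketch, not part of the original gist) As noted above, a backward
    # hook may return a tuple of new gradients to be used in place of grad_input.
    # `grad_scale_hook_wrapper` and `scale` are illustrative names.
    def grad_scale_hook_wrapper(scale):
        def grad_scale_hook(module, grad_input, grad_output):
            # one entry per element of grad_input; None entries are kept as None
            return tuple(g * scale if g is not None else None for g in grad_input)
        return grad_scale_hook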

    # Example to illustrate each hook's arguments

    model = torch.nn.ReLU()
    handle_pre = model.register_forward_pre_hook(pre_hook_wrapper('forward_pre_hook'))
    handle_forward = model.register_forward_hook(forward_hook_wrapper('forward_hook'))
    handle_backward = model.register_backward_hook(backward_hook_wrapper('backward_hook'))

    x = torch.tensor([-1.0, 1.0], requires_grad=True)
    out = model(x)
    print(out)
    out = torch.sum(out)
    print(out)
    out.backward()


    # forward_pre_hook :
    # input: (tensor([-1., 1.], requires_grad=True),)
    # forward_hook :
    # input: (tensor([-1., 1.], requires_grad=True),)
    # output: tensor([0., 1.], grad_fn=<ThresholdBackward0>)
    # tensor([0., 1.], grad_fn=<ThresholdBackward0>)
    # tensor(1., grad_fn=<SumBackward0>)
    # backward_hook :
    # grad_input: (tensor([0., 1.]),)
    # grad_output: (tensor([1., 1.]),)
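
    # Note (added, not part of the original gist): each register_*_hook call
    # returns a torch.utils.hooks.RemovableHandle; call .remove() once the hook
    # is no longer needed so it stops firing.
    handle_pre.remove()
    handle_forward.remove()
    handle_backward.remove()

    # In newer PyTorch releases, register_full_backward_hook is the recommended
    # replacement for register_backward_hook, with the same
    # (module, grad_input, grad_output) hook signature.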