Skip to content

Instantly share code, notes, and snippets.

@yusuke0519
Last active February 2, 2021 10:01
Show Gist options
  • Select an option

  • Save yusuke0519/4945c213a49332d683c77203c62a4247 to your computer and use it in GitHub Desktop.

Select an option

Save yusuke0519/4945c213a49332d683c77203c62a4247 to your computer and use it in GitHub Desktop.

Revisions

  1. yusuke0519 revised this gist Feb 2, 2021. 1 changed file with 3 additions and 5 deletions.
    8 changes: 3 additions & 5 deletions get_activation_pytorch.py
    Original file line number Diff line number Diff line change
    @@ -17,14 +17,12 @@ def __init__(self):

    # add hook to store activatiosns
    self.activations = {}
    def get_activation(name):
    def hook(model, input, output):
    self.activations[name] = output
    return hook
    def store_activations(model, input, output):
    self.activations[model.__name__] = output

    for name, layer in self.model.named_children():
    layer.__name__ = name
    layer.register_forward_hook(get_activation(name))
    layer.register_forward_hook(store_activations)

    def forward(self, x):
    return self.model(x.view(-1, 784))
  2. yusuke0519 created this gist Feb 2, 2021.
    33 changes: 33 additions & 0 deletions get_activation_pytorch.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,33 @@
    class MLPEncoder(nn.Module):
    def __init__(self):
    super(MLPEncoder, self).__init__()

    # TODO: Fix hard coding
    self.model = nn.Sequential(OrderedDict([
    ('layer1', nn.Linear(784, 400)),
    ('relu1', nn.ReLU()),
    ('layer2', nn.Linear(400, 400)),
    ('relu2', nn.ReLU()),
    ('layer3', nn.Linear(400, 200)),
    ('relu3', nn.ReLU()),
    ('layer4', nn.Linear(200, 200)),
    ('relu4', nn.ReLU())
    ]))
    print(self.model)

    # add hook to store activatiosns
    self.activations = {}
    def get_activation(name):
    def hook(model, input, output):
    self.activations[name] = output
    return hook

    for name, layer in self.model.named_children():
    layer.__name__ = name
    layer.register_forward_hook(get_activation(name))

    def forward(self, x):
    return self.model(x.view(-1, 784))

    def get_activations(name):
    return self.activations[name]