Skip to content

Instantly share code, notes, and snippets.

@yongjun823
Created September 18, 2019 11:27
Show Gist options
  • Select an option

  • Save yongjun823/ee387e8f3b424c310bbff80222c0d6e0 to your computer and use it in GitHub Desktop.

Select an option

Save yongjun823/ee387e8f3b424c310bbff80222c0d6e0 to your computer and use it in GitHub Desktop.

Revisions

  1. yongjun823 revised this gist Sep 18, 2019. 1 changed file with 0 additions and 2 deletions.
    2 changes: 0 additions & 2 deletions pr_vgg.py
    Original file line number Diff line number Diff line change
    @@ -24,8 +24,6 @@ def __init__(self):

    def forward(self, x):
    x = self.net(x)
    # x = torch.flatten(x, 1)
    # x = self.classifier(x)

    return x

  2. yongjun823 created this gist Sep 18, 2019.
    42 changes: 42 additions & 0 deletions pr_vgg.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,42 @@
    import torch
    import torch.nn as nn
    import torchvision.models as models
    from pprint import pprint

    class Net(nn.Module):
    def __init__(self):
    super().__init__()
    model = models.vgg19(pretrained=False)

    pprint(list(model.children()))

    model = list(model.children())[:-2]
    model = model[0]
    model = list(model.children())

    vgg_arr = []

    for xx in model:
    if 'Max' not in xx.__class__.__name__:
    vgg_arr.append(xx)

    self.net = nn.Sequential(*vgg_arr)

    def forward(self, x):
    x = self.net(x)
    # x = torch.flatten(x, 1)
    # x = self.classifier(x)

    return x


    device = torch.device("cuda:0")

    model = Net().to(device)
    model.eval()

    t = torch.randn((1, 3, 224, 224)).to(device)

    out = model(t)

    print(out.shape)