"""ResNet-style image-to-image generator plus a small inference script.

Run as a script with ``MODEL_PATH IMAGE_PATH`` to load a saved state dict,
push one image through the generator and display the result.
"""

import torch
import torch.nn as nn


class ResnetBlock(nn.Module):
    """Residual block: ``x + F(x)`` where F is two padded 3x3 convolutions.

    F is ReflectionPad(1) -> Conv3x3 -> InstanceNorm -> ReLU ->
    ReflectionPad(1) -> Conv3x3 -> InstanceNorm, so the channel count and
    the spatial size of the input are preserved.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super(ResnetBlock, self).__init__()
        # The attribute name is part of the state_dict keys; do not rename.
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        """Return the ``nn.Sequential`` implementing the residual branch."""
        # Layer order fixes the numeric state_dict keys (conv layers sit at
        # indices 1 and 5), so it must stay exactly as is.
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(dim),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(dim),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        return x + self.conv_block(x)


class ResnetGenerator(nn.Module):
    """Encoder / residual-blocks / decoder generator ending in ``tanh``.

    Structure, in order:
      1. ReflectionPad(3) + 7x7 conv to ``ngf`` channels, InstanceNorm, ReLU.
      2. Two stride-2 3x3 convs, each doubling the channel count.
      3. ``n_blocks`` ``ResnetBlock`` modules at ``4 * ngf`` channels.
      4. Two stride-2 transposed 3x3 convs, each halving the channel count.
      5. ReflectionPad(3) + 7x7 conv to ``output_nc`` channels, then Tanh.

    Args:
        input_nc: number of channels of the input tensor.
        output_nc: number of channels of the output tensor.
        ngf: number of filters produced by the first convolution.
        n_blocks: number of residual blocks between down- and up-sampling.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=6):
        assert n_blocks >= 0, "n_blocks must be non-negative"
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        n_downsampling = 2

        # NOTE: the position of every layer in this list determines the
        # state_dict keys ("model.<index>...."); keep the order unchanged so
        # previously saved checkpoints still load.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True),
        ]

        # Downsampling: ngf -> 2*ngf -> 4*ngf, halving H and W each time.
        for i in range(n_downsampling):
            in_ch = ngf * 2 ** i
            out_ch = in_ch * 2
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1,
                          bias=True),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(True),
            ]

        # Residual blocks at the bottleneck width.
        bottleneck_ch = ngf * 2 ** n_downsampling
        layers += [ResnetBlock(bottleneck_ch) for _ in range(n_blocks)]

        # Upsampling: 4*ngf -> 2*ngf -> ngf, doubling H and W each time.
        for i in range(n_downsampling):
            in_ch = ngf * 2 ** (n_downsampling - i)
            out_ch = in_ch // 2
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                                   padding=1, output_padding=1, bias=True),
                nn.InstanceNorm2d(out_ch),
                nn.ReLU(True),
            ]

        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the network; the result lies in [-1, 1]."""
        # The parameter name shadows the builtin but is kept because callers
        # may pass it by keyword.
        return self.model(input)


def main(argv=None):
    """Load a generator checkpoint, translate one image and display it.

    Args:
        argv: argument vector ``[prog, MODEL_PATH, IMAGE_PATH]``; defaults to
            ``sys.argv``.

    Raises:
        SystemExit: if fewer than two positional arguments are supplied.
    """
    import sys

    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        prog = argv[0] if argv else "script"
        sys.exit("usage: %s MODEL_PATH IMAGE_PATH" % prog)

    # Imported lazily so the model classes stay usable without PIL/torchvision.
    from PIL import Image
    from torchvision import transforms

    model_path, image_path = argv[1], argv[2]

    input_nc = 3
    output_nc = 3
    ngf = 64
    n_blocks = 9

    netG = ResnetGenerator(input_nc, output_nc, ngf, n_blocks=n_blocks)
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    netG.load_state_dict(torch.load(model_path, map_location="cpu"))
    netG.eval()

    # NOTE(review): the output is mapped back from [-1, 1] below, but the
    # input is fed in [0, 1] without a Normalize step -- confirm this matches
    # how the checkpoint was trained.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
    ])

    # Force three channels: grayscale/palette/RGBA files would otherwise not
    # match the input_nc=3 first convolution.
    img = Image.open(image_path).convert("RGB")
    batch = torch.unsqueeze(preprocess(img), 0)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        out = netG(batch)

    # Drop only the batch axis, then map tanh output [-1, 1] to [0, 1].
    out_t = (out.squeeze(0) + 1.0) / 2.0
    out_img = transforms.ToPILImage()(out_t)
    out_img.show()


if __name__ == '__main__':
    main()