@pranavpandey2511
Created February 26, 2020 08:17
Revisions

  1. pranavpandey2511 created this gist Feb 26, 2020.
cycle_gan_model.py

import torch
import itertools
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
from PIL import Image
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from models.vgg19 import VGG19
from models.contextual_loss import contextual_loss as CX
import os

vgg19_path = os.path.join('.', 'vgg', 'imagenet-vgg-verydeep-19.mat')
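# NOTE: the contextual loss expects pre-trained VGG-19 weights at this path
# (the 'imagenet-vgg-verydeep-19.mat' release, i.e. MatConvNet format);
# adjust vgg19_path if the file lives elsewhere.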

class CycleGANModel(BaseModel):
    def name(self):
        return 'CycleGANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0,
                                help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5,
                                help='use identity mapping. Setting lambda_identity other than 0 scales the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, set lambda_identity = 0.1')
            parser.add_argument('--lambda_G_A', type=float, default=1.0,
                                help='weight for the G_A GAN loss')
            parser.add_argument('--lambda_D_A', type=float, default=1.0,
                                help='weight for the D_A loss')
            parser.add_argument('--lambda_scale_G_A', type=float, default=10.0,
                                help='weight for the multi-scale gradient loss on G_A (real_A vs. fake_B)')
            parser.add_argument('--lambda_scale_G_B', type=float, default=10.0,
                                help='weight for the multi-scale gradient loss on G_B (real_B vs. fake_A)')
            parser.add_argument('--no_identity_b', action='store_true',
                                help='also compute the identity loss for G_B; otherwise it is zero')
            parser.add_argument('--l0_reg', action='store_true',
                                help='add the L0-style sparsity regularizer on fake_B to the loss; otherwise it is zero')
            parser.add_argument('--try_a', action='store_true',
                                help='experimental flag; when set, the B-side identity loss is skipped')
            parser.add_argument('--contextual_loss', action='store_true',
                                help='enable the VGG-based contextual loss')

        return parser

    def initialize(self, opt):
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B',
                           'L0_reg', 'scale_G_A', 'scale_G_B', 'contextual']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')

        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # load/define networks
        # The naming convention here differs from the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, is_a=True)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, is_a=False)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                            self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                            self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionScale = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, opt.beta2))

            if self.opt.contextual_loss:
                # separate optimizer for the contextual term, run at 25x the base learning rate
                self.optimizer_G_contextual = torch.optim.Adam(
                    itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                    lr=opt.lr * 25, betas=(opt.beta1, opt.beta2))

            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            if self.opt.contextual_loss:
                self.optimizers.append(self.optimizer_G_contextual)

        # 3x3 Laplacian kernel used by the multi-scale gradient loss
        if len(opt.gpu_ids) > 0:
            self.grad_conv_filter = torch.Tensor([[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]).cuda().unsqueeze(0)
        else:
            self.grad_conv_filter = torch.Tensor([[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]).unsqueeze(0)

        # frozen VGG-19 feature extractor for the contextual loss
        if self.opt.contextual_loss and len(opt.gpu_ids) > 0:
            self.vgg19 = VGG19(vgg19_path).cuda()
            self.vgg19.eval()
        elif self.opt.contextual_loss:
            self.vgg19 = VGG19(vgg19_path)
            self.vgg19.eval()
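
        # eval() only disables dropout / batch-norm updates in the VGG
        # feature extractor; the contextual loss additionally wraps the
        # target pass in torch.no_grad() (see calc_contextual_loss), so
        # gradients flow only through the generated images.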

    def set_input(self, input):
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.fake_B = self.netG_A(self.real_A)
        # self.rec_A = self.netG_B(self.fake_B)
        self.rec_A = self.real_A

        # earlier try_a variant:
        # if self.opt.try_a:
        #     self.fake_A = self.real_B
        #     self.rec_B = self.real_B
        # else:
        #     self.fake_A = self.netG_B(self.real_B)
        #     self.rec_B = self.netG_A(self.fake_A)

        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)
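
        # Note: the A-side reconstruction is bypassed (rec_A is simply
        # real_A), so only the B -> A -> B cycle provides a reconstruction
        # signal; backward_G accordingly sets loss_cycle_A = 0.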



    def calc_scale_loss(self, real, fake):
        list_of_scale = [1, 2, 4, 8, 16]
        scale_factor = [0.0001, 0.001, 0.01, 0.1, 1]
        # single-scale variant: list_of_scale = [1]; scale_factor = [1]

        _, __, orig_h, orig_w = real.shape  # tensors are NCHW
        loss_scale = 0

        for index, scale in enumerate(list_of_scale):
            scaled_h = int(orig_h / scale)
            scaled_w = int(orig_w / scale)
            scaled_real = F.adaptive_avg_pool3d(self.rgb2gray(real), (1, scaled_h, scaled_w))
            scaled_fake = F.adaptive_avg_pool3d(self.rgb2gray(fake), (1, scaled_h, scaled_w))

            grad_scaled_real = F.conv2d(scaled_real, self.grad_conv_filter, padding=1)  # TODO: check padding/boundary handling
            grad_scaled_fake = F.conv2d(scaled_fake, self.grad_conv_filter, padding=1)

            # earlier experiment: mask out bright regions before comparing gradients
            # my_filter = torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]]]).cuda().unsqueeze(0)
            # image_filter = F.conv2d(scaled_real, my_filter, padding=1)
            # scaled = image_filter / 9
            # use_filter = (scaled < 0.3).type(torch.cuda.FloatTensor)
            # white = scaled * use_filter + 1 * (1 - use_filter)
            # white = 1 - white
            # curr_loss = scale_factor[index] * self.criterionScale(grad_scaled_fake * white, grad_scaled_real * white)

            curr_loss = scale_factor[index] * self.criterionScale(grad_scaled_fake, grad_scaled_real)
            loss_scale += curr_loss  # TODO: tune the overall factor (10 worked best so far)
            # self.save_image2(grad_scaled_fake)

        return loss_scale
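
    # The multi-scale loss above compares Laplacian responses of grayscale
    # pyramids: each scale s in {1, 2, 4, 16} downsamples by s via adaptive
    # average pooling, filters with the 3x3 Laplacian kernel, and takes an L1
    # distance, with coarser scales weighted more heavily (1e-4 up to 1).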


    def calc_contextual_loss(self, generated_image, target_image):
        generated_image_vgg19_layers = self.vgg19(generated_image)
        with torch.no_grad():
            target_image_vgg19_layers = self.vgg19(target_image)

        loss = 0

        lambdas = []       # scaling parameters for vgg layers
        num_elements = []  # number of elements in each vgg layer
        for img_layer in generated_image_vgg19_layers:
            num_elem = np.prod(img_layer.size()[1:])
            num_elements.append(num_elem)
            lambdas.append(1.0 / num_elem)

        lambdas = [lam / sum(lambdas) for lam in lambdas]

        for i in range(len(generated_image_vgg19_layers)):
            loss += float(lambdas[i]) * CX(generated_image_vgg19_layers[i], target_image_vgg19_layers[i])

        return loss
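
    # Design notes: the target pass runs under torch.no_grad(), so gradients
    # flow only through the generated image's VGG features; each layer's
    # contextual distance is weighted proportionally to 1 / (number of
    # elements in the layer) and the weights are renormalized to sum to one,
    # presumably so large shallow layers do not dominate.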

    def save_image2(self, output, file_name='yoav.png'):
        b, c, w, h = output.shape
        output = torch.clamp((output + 1) / 2, 0, 1)  # map from [-1, 1] to [0, 1]
        # output = self.gray2rgb(output)
        output = output.permute(0, 2, 3, 1)[0, :, :, :]
        I = output.data[:, :, 0]
        I = I.cpu().numpy()
        I8 = (((I - I.min()) / (I.max() - I.min())) * 255.9).astype(np.uint8)
        img = Image.fromarray(I8)
        img.save(file_name)

    def save_image(self, output, file_name='yoav.png'):  # e.g. self.save_image(self.real_B[1, :, :, :].squeeze(0), 'real.png')
        __, w, h = output.shape
        output = torch.clamp(output + 0.5, 0, 1)
        picture = np.zeros((w, h, 3))
        picture[:, :, 0] = output.data[0, :, :].cpu().numpy()
        picture[:, :, 1] = output.data[1, :, :].cpu().numpy()
        picture[:, :, 2] = output.data[2, :, :].cpu().numpy()
        plt.imshow(picture)
        plt.savefig(file_name)


    # leftover debug snippet for visualizing |grad_scaled_real - grad_scaled_fake|:
    # grad = torch.abs(grad_scaled_real - grad_scaled_fake)
    # _, __, w, h = grad.shape
    # output = torch.clamp(grad, 0, 1)
    # picture = np.zeros((w, h, 3))
    # output = output.squeeze(0)
    # picture[:, :, 0] = output.data[0, :, :]
    # picture[:, :, 1] = output.data[0, :, :]
    # picture[:, :, 2] = output.data[0, :, :]
    # plt.imshow(picture)
    # plt.savefig('yoav.png')
    # self.save_image(real.squeeze(0))

    # def backward_D_new(self, netD, fake):
    #     pred_fake = netD(fake)
    #     loss_D_fake = self.criterionGAN(pred_fake, False)
    #     # Combined loss
    #     loss_D = loss_D_fake * 0.5
    #     # backward
    #     loss_D.backward()
    #     return loss_D

    def backward_D_basic(self, netD, real, fake, g_real=None, do_another_value=False):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)

        if do_another_value:  # extra "real" term on the identity output; TODO: make sure it helps
            loss_D_real2 = self.criterionGAN(netD(g_real.detach()), True)
        else:
            loss_D_real2 = 0

        # Combined loss
        loss_D = (loss_D_real + loss_D_real2) * 0.5 + loss_D_fake
        # backward
        loss_D.backward()

        return loss_D
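
    # Unlike stock CycleGAN, which weights the real and fake terms by 0.5
    # each, this variant halves only the real-side terms and keeps the fake
    # term at full weight; backward_D_A additionally feeds the identity
    # output idt_A to D_A as a second "real" example (do_another_value=True).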

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        # NB: the scaling below affects only the logged value; backward() has
        # already run inside backward_D_basic, so lambda_D_A does not rescale
        # the gradients. do_another_value=True assumes idt_A exists, i.e.
        # --lambda_identity > 0.
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, self.idt_A,
                                              do_another_value=True) * self.opt.lambda_D_A * 0.5

    def backward_D_B(self):  # TODO: rethink
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

        # earlier try_a variant:
        # if self.opt.try_a:
        #     self.loss_D_B = 0
        # else:
        #     fake_A = self.fake_A_pool.query(self.fake_A)
        #     self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G_contextual_loss(self):
        self.loss_contextual = self.calc_contextual_loss(self.fake_B, self.real_B)  # TODO: consider adding a term for the input image

        # combined loss: only the contextual term drives this backward pass
        self.loss_G = self.loss_contextual
        self.loss_G.backward()
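
    # This contextual backward pass is deliberately separate from backward_G:
    # optimize_parameters() runs it as its own zero_grad/backward/step cycle
    # on optimizer_G_contextual (25x the base learning rate).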

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # Identity loss
        if lambda_idt > 0:
            # G_A should be the identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be the identity if real_A is fed.
            if not self.opt.try_a and self.opt.no_identity_b:
                self.idt_B = self.netG_B(self.real_A)
                self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
            else:
                self.idt_B = self.real_A
                # identity loss for G_B removed: the noise generator is not
                # expected to reproduce its noisy input
                self.loss_idt_B = 0
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        if self.opt.l0_reg:
            # earlier version:
            # self.loss_L0_reg = 0.000003 * torch.sum(-1 * torch.clamp(self.fake_B, -0.5, 0.5) + 0.5)

            # push grayscale intensities of fake_B toward the extremes: after
            # mapping to [0, 1], entries below 0.5 are pulled toward 0 and
            # entries above 0.5 toward 1
            image = -1 * torch.clamp(self.rgb2gray(self.fake_B), -0.5, 0.5) + 0.5
            mask_toward_zero = image.clone()
            mask_toward_one = image.clone()

            mask_toward_zero[mask_toward_zero > 0.5] = 0
            mask_toward_one[mask_toward_one < 0.5] = 1

            self.loss_L0_reg = 0.0001 * (torch.sum(mask_toward_zero) + torch.sum(1 - mask_toward_one))  # TODO: expose the 0.0001 factor as a parameter
        else:
            self.loss_L0_reg = 0

        self.loss_scale_G_A = self.opt.lambda_scale_G_A * self.calc_scale_loss(self.real_A, self.fake_B)
        self.loss_scale_G_B = self.opt.lambda_scale_G_B * self.calc_scale_loss(self.real_B, self.fake_A)

        # GAN loss D_A(G_A(A)), plus a GAN term on the identity output idt_A
        # (assumes lambda_identity > 0, otherwise idt_A is never assigned)
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True) * self.opt.lambda_G_A \
                        + self.criterionGAN(self.netD_A(self.idt_A), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)  # TODO: rethink

        # Forward cycle loss is disabled (rec_A is real_A in forward());
        # previously: self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        self.loss_cycle_A = 0
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        # combined loss (the multi-scale terms are computed and logged above
        # but currently excluded from the sum)
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B \
                      + self.loss_idt_A + self.loss_idt_B + self.loss_L0_reg  # + self.loss_scale_G_A + self.loss_scale_G_B
        self.loss_G.backward()
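
    # Full generator objective as implemented above:
    #   L_G = lambda_G_A * GAN(D_A(fake_B)) + GAN(D_A(idt_A)) + GAN(D_B(fake_A))
    #         + lambda_B * ||rec_B - real_B||_1
    #         + lambda_idt * (lambda_B * ||idt_A - real_B||_1
    #                         + lambda_A * ||idt_B - real_A||_1)   (zero unless --no_identity_b)
    #         + L0_reg                                             (zero unless --l0_reg)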

    def rgb2gray(self, rgb):
        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b

        return gray
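
    # rgb2gray uses the standard ITU-R BT.601 luma coefficients
    # (0.2989 R + 0.5870 G + 0.1140 B), matching MATLAB's rgb2gray.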

    def gray2rgb(self, gray):
        batch, c, w, h = gray.shape

        if c == 1:
            # weight the single channel into R/G/B (note: this scales by the
            # luma coefficients rather than replicating the channel, so it is
            # not an exact inverse of rgb2gray)
            r, g, b = 0.2989 * gray, 0.5870 * gray, 0.1140 * gray

            rgb = torch.zeros((batch, 3, w, h), device=gray.device)
            rgb[:, 0:1, :, :] = r
            rgb[:, 1:2, :, :] = g
            rgb[:, 2:3, :, :] = b

            return rgb
        else:
            return gray


    def print_layer_grad(self, initial_print):
        model = self.netG_A
        modules_list = list(model.modules())
        layer_list = [x for x in modules_list if isinstance(x, (torch.nn.Conv2d, torch.nn.Linear))]
        grad_list = []
        for layer in layer_list:
            grad_list.append(float(torch.norm(layer._parameters['weight'].grad, 1)))

        print(initial_print, grad_list)

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        # self.print_layer_grad("regular ")
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()

        self.loss_contextual = 0
        if self.opt.contextual_loss:
            torch.cuda.empty_cache()
            # regenerate fake_B with the just-updated G_A before the contextual pass
            self.fake_B = self.netG_A(self.real_A)
            self.optimizer_G_contextual.zero_grad()
            self.backward_G_contextual_loss()
            # self.print_layer_grad("con ")
            self.optimizer_G_contextual.step()
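

# A minimal smoke test (an illustrative sketch, not part of the original
# training pipeline): exercises the 3x3 Laplacian kernel used by
# calc_scale_loss on a toy edge image. Assumes the repository imports above
# resolve, e.g. run as `python -m models.cycle_gan_model` from the repo root.
if __name__ == '__main__':
    lap = torch.Tensor([[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]).unsqueeze(0)  # (1, 1, 3, 3)
    img = torch.zeros(1, 1, 8, 8)
    img[:, :, :, 4:] = 1.0  # vertical step edge
    grad = F.conv2d(img, lap, padding=1)
    # the response is non-zero at the step edge (and at the image border,
    # where zero padding creates an artificial edge)
    print('Laplacian response magnitude:', grad.abs().sum().item())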