pranavpandey2511 created this gist on February 26, 2020.
import torch
import itertools
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
from PIL import Image
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from models.vgg19 import VGG19
from models.contextual_loss import contextual_loss as CX
import os

vgg19_path = os.path.join('.', 'vgg', 'imagenet-vgg-verydeep-19.mat')


class CycleGANModel(BaseModel):
    def name(self):
        return 'CycleGANModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True)
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5,
                                help='use identity mapping. Setting lambda_identity other than 0 scales the weight '
                                     'of the identity mapping loss. For example, if the weight of the identity loss '
                                     'should be 10 times smaller than the weight of the reconstruction loss, set '
                                     'lambda_identity = 0.1')
            parser.add_argument('--lambda_G_A', type=float, default=1.0, help='weight for the G_A GAN loss')
            parser.add_argument('--lambda_D_A', type=float, default=1.0, help='weight for the D_A loss')
            parser.add_argument('--lambda_scale_G_A', type=float, default=10.0, help='weight for the multi-scale gradient loss on G_A')
            parser.add_argument('--lambda_scale_G_B', type=float, default=10.0, help='weight for the multi-scale gradient loss on G_B')
            parser.add_argument('--no_identity_b', action='store_true', help='add identity_B to the loss; otherwise it is zero')
            parser.add_argument('--l0_reg', action='store_true', help='add the L0 regularizer to the loss; otherwise it is zero')
            parser.add_argument('--try_a', action='store_true', help='experimental flag; see the commented-out try_a branches below')
            parser.add_argument('--contextual_loss', action='store_true', help='enable the contextual loss')
        return parser

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B',
                           'L0_reg', 'scale_G_A', 'scale_G_B', 'contextual']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call
        # base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load the generators
            self.model_names = ['G_A', 'G_B']

        # load/define networks
        # The naming convention is different from the one used in the paper:
        # code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, is_a=True)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, is_a=False)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionScale = torch.nn.L1Loss()
            # self.l0norm = torch.sum()  # torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, opt.beta2))
            if self.opt.contextual_loss:
                self.optimizer_G_contextual = torch.optim.Adam(
                    itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                    lr=opt.lr * 25, betas=(opt.beta1, opt.beta2))
            self.optimizers = [self.optimizer_G, self.optimizer_D]
            if self.opt.contextual_loss:
                self.optimizers.append(self.optimizer_G_contextual)

        # 3x3 Laplacian kernel used by the multi-scale gradient loss
        # a 3-channel variant of the kernel was tried as well:
        # self.grad_conv_filter = torch.Tensor([[[0, 1, 0], [1, -4, 1], [0, 1, 0]],
        #                                       [[0, 1, 0], [1, -4, 1], [0, 1, 0]],
        #                                       [[0, 1, 0], [1, -4, 1], [0, 1, 0]]]).cuda().unsqueeze(0)
        if len(opt.gpu_ids) > 0:
            self.grad_conv_filter = torch.Tensor([[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]).cuda().unsqueeze(0)
        else:
            self.grad_conv_filter = torch.Tensor([[[0, 1, 0], [1, -4, 1], [0, 1, 0]]]).unsqueeze(0)

        if self.opt.contextual_loss and len(opt.gpu_ids) > 0:
            self.vgg19 = VGG19(vgg19_path).cuda()
            self.vgg19.eval()
        elif self.opt.contextual_loss:
            self.vgg19 = VGG19(vgg19_path)
            self.vgg19.eval()

    def set_input(self, input):
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.fake_B = self.netG_A(self.real_A)
        # the A cycle is disabled: rec_A is just real_A, so cycle loss A is effectively off
        # self.rec_A = self.netG_B(self.fake_B)
        self.rec_A = self.real_A
        # if self.opt.try_a:
        #     self.fake_A = self.real_B
        #     self.rec_B = self.real_B
        # else:
        #     self.fake_A = self.netG_B(self.real_B)
        #     self.rec_B = self.netG_A(self.fake_A)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def calc_scale_loss(self, real, fake):
        # L1 distance between Laplacian responses of real and fake at five scales;
        # coarser scales get smaller weights
        list_of_scale = [1, 2, 4, 8, 16]
        scale_factor = [0.0001, 0.001, 0.01, 0.1, 1]
        # list_of_scale = [1]; scale_factor = [1]
        _, __, orig_w, orig_h = real.shape
        loss_scale = 0
        for index, scale in enumerate(list_of_scale):
            scaled_w = int(orig_w / scale)
            scaled_h = int(orig_h / scale)
            # adaptive 3D pooling over (1, H, W) downsamples the grayscale image
            scaled_real = F.adaptive_avg_pool3d(self.rgb2gray(real), (1, scaled_w, scaled_h))
            scaled_fake = F.adaptive_avg_pool3d(self.rgb2gray(fake), (1, scaled_w, scaled_h))
            grad_scaled_real = F.conv2d(scaled_real, self.grad_conv_filter, padding=1)  # TODO padding
            grad_scaled_fake = F.conv2d(scaled_fake, self.grad_conv_filter, padding=1)  # TODO padding
            # commented-out experiment: mask out bright regions before comparing gradients
            # my_filter = torch.Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]]]).cuda().unsqueeze(0)
            # image_filter = F.conv2d(scaled_real, my_filter, padding=1)
            # scaled = image_filter / 9
            # use_filter = (scaled < 0.3).type(torch.cuda.FloatTensor)
            # white = scaled * use_filter + 1 * (1 - use_filter)
            # white = 1 - white
            # grad_scaled_real.required_grad = False
            # curr_loss = scale_factor[index] * self.criterionScale(grad_scaled_fake * white, grad_scaled_real * white)
            curr_loss = scale_factor[index] * self.criterionScale(grad_scaled_fake, grad_scaled_real)
            loss_scale += curr_loss  # TODO factor (best for now is 10)
        # self.save_image2(grad_scaled_fake)
        # self.save_image2(white)
        return loss_scale
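    # A minimal standalone sketch of the multi-scale gradient loss above, assuming
    # already-grayscale 1x1xHxW tensors; `lap` plays the role of self.grad_conv_filter,
    # and adaptive_avg_pool2d is numerically equivalent to the pool3d call above for
    # Nx1xHxW inputs. Kept as a comment so this module stays import-safe:
    #
    #   import torch
    #   import torch.nn.functional as F
    #   lap = torch.tensor([[[[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]]]])
    #   real, fake = torch.rand(1, 1, 64, 64), torch.rand(1, 1, 64, 64)
    #   loss = 0
    #   for s, w in zip([1, 2, 4, 8, 16], [1e-4, 1e-3, 1e-2, 1e-1, 1.]):
    #       r = F.adaptive_avg_pool2d(real, (64 // s, 64 // s))   # downsample
    #       f = F.adaptive_avg_pool2d(fake, (64 // s, 64 // s))
    #       loss += w * F.l1_loss(F.conv2d(f, lap, padding=1),    # compare edge maps
    #                             F.conv2d(r, lap, padding=1))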
    def calc_contextual_loss(self, generated_image, target_image):
        generated_image_vgg19_layers = self.vgg19(generated_image)
        with torch.no_grad():
            target_image_vgg19_layers = self.vgg19(target_image)
        loss = 0
        lambdas = []       # scaling parameters for the VGG layers
        num_elements = []  # number of elements in each VGG layer
        for img_layer in generated_image_vgg19_layers:
            num_elem = np.prod(img_layer.size()[1:])
            num_elements.append(num_elem)
            lambdas.append(1.0 / num_elem)
        # normalize the weights so they sum to 1; each layer is weighted
        # inversely to its number of elements
        lambdas = [lamda / sum(lambdas) for lamda in lambdas]
        for i in range(len(generated_image_vgg19_layers)):
            loss += float(lambdas[i]) * CX(generated_image_vgg19_layers[i], target_image_vgg19_layers[i])
        return loss
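    # A worked example of the layer weighting above (illustrative sizes, not the
    # actual VGG19 layer sizes): suppose two layers with 1000 and 4000 elements.
    # Raw weights are 1/1000 = 1e-3 and 1/4000 = 2.5e-4, which sum to 1.25e-3, so
    # the normalized lambdas come out to 0.8 and 0.2; the smaller layer dominates:
    #
    #   sizes = [1000, 4000]
    #   raw = [1.0 / n for n in sizes]         # [0.001, 0.00025]
    #   lambdas = [r / sum(raw) for r in raw]  # [0.8, 0.2]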
    def save_image2(self, output, file_name='yoav.png'):
        b, c, w, h = output.shape
        output = torch.clamp((output + 1) / 2, 0, 1)
        # output = self.gray2rgb(output)
        output = output.permute(0, 2, 3, 1)[0, :, :, :]
        I = output.data[:, :, 0]
        I = I.cpu().numpy()
        I8 = (((I - I.min()) / (I.max() - I.min())) * 255.9).astype(np.uint8)
        img = Image.fromarray(I8)
        img.save(file_name)
        # picture = np.zeros((w, h, 3))
        # picture[:, :, 0] = output.data[0, 0:1, :, :]
        # picture[:, :, 1] = output.data[0, 1:2, :, :]
        # picture[:, :, 2] = output.data[0, 2:3, :, :]
        # plt.imshow(output.data[:, :, 0])
        # plt.savefig(file_name)

    def save_image(self, output, file_name='yoav.png'):
        # example call: self.save_image(self.real_B[1, :, :, :].squeeze(0), 'real.png')
        __, w, h = output.shape
        output = torch.clamp(output + 0.5, 0, 1)
        picture = np.zeros((w, h, 3))
        picture[:, :, 0] = output.data[0, :, :]
        picture[:, :, 1] = output.data[1, :, :]
        picture[:, :, 2] = output.data[2, :, :]
        plt.imshow(picture)
        plt.savefig(file_name)

    # commented-out debugging snippet: visualize |grad_scaled_real - grad_scaled_fake|
    # grad = torch.abs(grad_scaled_real - grad_scaled_fake)
    # _, __, w, h = grad.shape
    # output = torch.clamp(grad, 0, 1)
    # picture = np.zeros((w, h, 3))
    # output = output.squeeze(0)
    # picture[:, :, 0] = output.data[0, :, :]
    # picture[:, :, 1] = output.data[0, :, :]
    # picture[:, :, 2] = output.data[0, :, :]
    # plt.imshow(picture)
    # plt.savefig('yoav.png')
    # self.save_image(real.squeeze(0))

    # def backward_D_new(self, netD, fake):
    #     pred_fake = netD(fake)
    #     loss_D_fake = self.criterionGAN(pred_fake, False)
    #     # Combined loss
    #     loss_D = loss_D_fake * 0.5
    #     # backward
    #     loss_D.backward()
    #     return loss_D

    def backward_D_basic(self, netD, real, fake, g_real=None, do_another_value=False):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        if do_another_value:  # TODO
            # also push the detached identity output towards the "real" label
            loss_D_real2 = self.criterionGAN(netD(g_real.detach()), True)  # TODO make sure it helps
        else:
            loss_D_real2 = 0
        # Combined loss
        loss_D = (loss_D_real + loss_D_real2) * 0.5 + loss_D_fake
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        # note: backward_D_basic has already called backward(), so the scaling
        # below only affects the reported loss value, not the gradients
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, self.idt_A,
                                              do_another_value=True) * self.opt.lambda_D_A * 0.5

    def backward_D_B(self):  # think
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        # if self.opt.try_a:
        #     self.loss_D_B = 0
        # else:
        #     fake_A = self.fake_A_pool.query(self.fake_A)
        #     self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
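    # A minimal sketch of what criterionGAN computes for the discriminator when
    # LSGAN is enabled (the default here, use_lsgan=not opt.no_lsgan): an MSE
    # against constant real/fake labels. Illustrative only; the actual GANLoss
    # implementation lives in this repo's networks.py:
    #
    #   import torch
    #   mse = torch.nn.MSELoss()
    #   pred_real = torch.rand(4, 1, 30, 30)   # patch-wise D scores on real images
    #   pred_fake = torch.rand(4, 1, 30, 30)   # patch-wise D scores on fakes
    #   loss_real = mse(pred_real, torch.ones_like(pred_real))   # push real scores to 1
    #   loss_fake = mse(pred_fake, torch.zeros_like(pred_fake))  # push fake scores to 0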
    def backward_G_contextual_loss(self):
        self.loss_contextual = self.calc_contextual_loss(self.fake_B, self.real_B)  # TODO consider adding it for the input image
        # combined loss: only the contextual term is used in this pass
        self.loss_G = self.loss_contextual
        self.loss_G.backward()

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be the identity if real_B is fed
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be the identity if real_A is fed
            if not self.opt.try_a and self.opt.no_identity_b:
                self.idt_B = self.netG_B(self.real_A)
                self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
            else:
                self.idt_B = self.real_A
                # identity loss on B removed: the noise generator is not expected
                # to reproduce its noisy input unchanged
                self.loss_idt_B = 0
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        if self.opt.l0_reg:
            # self.loss_L0_reg = 0.000003 * torch.sum(-1 * torch.clamp(self.fake_B, -0.5, 0.5) + 0.5)  # TODO add parameter
            # push each grayscale value towards its nearer extreme (black or white)
            image = -1 * torch.clamp(self.rgb2gray(self.fake_B), -0.5, 0.5) + 0.5
            mask_toward_zero = image.clone()
            mask_toward_one = image.clone()
            mask_toward_zero[mask_toward_zero > 0.5] = 0
            mask_toward_one[mask_toward_one < 0.5] = 1
            self.loss_L0_reg = 0.0001 * (torch.sum(mask_toward_zero) + torch.sum(1 - mask_toward_one))  # TODO add parameter
        else:
            self.loss_L0_reg = 0

        self.loss_scale_G_A = self.opt.lambda_scale_G_A * self.calc_scale_loss(self.real_A, self.fake_B)
        self.loss_scale_G_B = self.opt.lambda_scale_G_B * self.calc_scale_loss(self.real_B, self.fake_A)
        # GAN loss D_A(G_A(A)); note: assumes lambda_identity > 0 so that idt_A exists
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True) * self.opt.lambda_G_A \
                        + self.criterionGAN(self.netD_A(self.idt_A), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)  # think
        # if self.opt.try_a:
        #     self.loss_G_B = 0
        # else:
        #     self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        # self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        # if self.opt.try_a:
        #     self.loss_cycle_B = 0
        # else:
        #     self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # Forward cycle loss is disabled (rec_A is just real_A, see forward())
        self.loss_cycle_A = 0
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B \
                      + self.loss_idt_A + self.loss_idt_B + self.loss_L0_reg  # + self.loss_scale_G_A + self.loss_scale_G_B
        self.loss_G.backward()
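    # A small worked example of the l0_reg branch above (values chosen for
    # illustration, outputs assumed in [-1, 1]): for a gray value g, the code forms
    # v = -clamp(g, -0.5, 0.5) + 0.5, which lies in [0, 1]. A pixel with v <= 0.5
    # only contributes through mask_toward_zero (penalty v, pulling v to 0), and a
    # pixel with v > 0.5 only contributes through 1 - mask_toward_one (penalty
    # 1 - v, pulling v to 1). Every pixel is therefore pushed to its nearer
    # extreme, acting as a near-binary black-or-white prior on the output:
    #
    #   g = 0.1                               # mid-gray pixel
    #   v = -min(max(g, -0.5), 0.5) + 0.5     # 0.4 -> penalized by v, pushed to 0
    #   g = -0.2
    #   v = -min(max(g, -0.5), 0.5) + 0.5     # 0.7 -> penalized by 1 - v, pushed to 1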
    def rgb2gray(self, rgb):
        # BT.601 luma conversion, keeping the channel dimension (N x 1 x H x W)
        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray

    def gray2rgb(self, gray):
        batch, c, w, h = gray.shape
        if c == 1:
            r, g, b = 0.2989 * gray, 0.5870 * gray, 0.1140 * gray
            # allocate on the same device as the input instead of assuming CUDA
            rgb = torch.zeros((batch, 3, w, h), device=gray.device)
            rgb[:, 0:1, :, :] = r
            rgb[:, 1:2, :, :] = g
            rgb[:, 2:3, :, :] = b
            return rgb
        else:
            return gray

    def print_layer_grad(self, initial_print):
        # debugging helper: print the L1 norm of the weight gradients of every
        # Conv2d/Linear layer in G_A
        model = self.netG_A
        modules_list = list(model.modules())
        layer_list = [x for x in modules_list if isinstance(x, (torch.nn.Conv2d, torch.nn.Linear))]
        grad_list = []
        for layer in layer_list:
            grad_list.append(float(torch.norm(layer._parameters['weight'].grad, 1)))
        print(initial_print, grad_list)

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        # self.print_layer_grad("regular ")
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
        # optional third phase: a separate, higher-LR generator step on the
        # contextual loss, computed on a fresh forward pass of G_A
        self.loss_contextual = 0
        if self.opt.contextual_loss:
            torch.cuda.empty_cache()
            self.fake_B = self.netG_A(self.real_A)
            self.optimizer_G_contextual.zero_grad()
            self.backward_G_contextual_loss()
            # self.print_layer_grad("con ")
            self.optimizer_G_contextual.step()
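# A minimal sketch of how this model class is typically driven by the training
# loop in the CycleGAN codebase this gist derives from. The option object and the
# dataset iterator are assumptions based on that repo family's train.py
# conventions, not part of this file:
#
#   model = CycleGANModel()
#   model.initialize(opt)
#   for epoch in range(opt.niter + opt.niter_decay):
#       for data in dataset:              # each `data` dict holds 'A', 'B', 'A_paths', 'B_paths'
#           model.set_input(data)         # move the batch to the right device
#           model.optimize_parameters()   # G step, D step, optional contextual G step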