Created
January 2, 2020 08:22
-
-
Save VyBui/2e300e57b2860f3ca0360d341c39fefc to your computer and use it in GitHub Desktop.
Revisions
-
VyBui created this gist
Jan 2, 2020 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,156 @@ from __future__ import absolute_import, division, print_function, unicode_literals import logging.config import tensorflow as tf import os import numpy as np from tensorflow import ConfigProto from argument_parse import args from module.discriminator_vgg19 import Discriminator from module.generator import Generator from module.losses import l1_loss from config import cfg from data_tools.parse_records_dataset import input_fn from calculate_average_gradients import get_perturbed_batch, average_gradients from preprocessing.dataset import pre_processing from utils import shuffle if args.mode in ['train', 'test', 'val']: params = {'batch_size': cfg.train_batch_size, 'tfrecords_path': cfg.tfrecords_path} train_dataset = input_fn(args.mode, params) else: raise ValueError("mode must be via ( train, test or val).") if not any([isinstance(args.num_gpus, int), isinstance(args.batch_size, int)]): raise ValueError("num gpus or batch size must be type integer.") if args.mode in ['train', 'test', 'val']: params = {'batch_size': cfg.train_batch_size, 'tfrecords_path': cfg.tfrecords_path} train_dataset = input_fn(args.mode, params) # get TF logger # load logging confoguration and create log object logging.config.fileConfig('logging.conf') logging.basicConfig(filename='skin_generator.log', level=logging.DEBUG) log = logging.getLogger('TensorFlow') log.setLevel(logging.DEBUG) # create formatter and add it to the handlers formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') # create file handler which logs even debug messages fh = logging.FileHandler('module_2.log') fh.setLevel(logging.DEBUG) fh.setFormatter(formatter) log.addHandler(fh) if __name__ == '__main__': train_iterator = train_dataset.make_initializable_iterator() batch_data = train_iterator.get_next() image_label, body_parts, seg_parts, top_and_bottom = batch_data train_batch_size_step = args.batch_size // args.num_gpus # output of D for real images D_real, D_real_logits = Discriminator(image_label).feed_forward() # output of D for fake images gen, end_points = Generator(body_parts, seg_parts, top_and_bottom).feed_forward() D_fake, D_fake_logits = Discriminator(gen).feed_forward() label_input_perturbed = get_perturbed_batch = get_perturbed_batch(image_label) # get loss for discriminator with tf.name_scope('D_loss'): d_loss_real = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real))) d_loss_fake = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros_like(D_fake))) d_loss = d_loss_real + d_loss_fake alpha = tf.random_uniform(shape=tf.shape(image_label), minval=0., maxval=1.) differences = label_input_perturbed - image_label # This is different from WGAN-GP interpolates = image_label + (alpha * differences) _, D_inter = Discriminator(interpolates).feed_forward() gradients = tf.gradients(D_inter, [interpolates])[0] slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2) lambd = 0.1 d_loss += lambd * gradient_penalty # get loss for generator with tf.name_scope('G_loss'): g_mse_lambda = 100 g_mse_loss = tf.keras.losses.MSE(y_true=image_label, y_pred=gen) g_mse_loss = g_mse_loss * g_mse_lambda gen_loss = g_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones_like(D_fake))) + g_mse_loss # Training: divide trainable variables into a group for D and a group for G t_vars = tf.trainable_variables() d_vars = [var for var in t_vars if "discriminator" in var.name] g_vars = [var for var in t_vars if "generator" in var.name] # Optimizers with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): d_train_opt = tf.train.AdamOptimizer(learning_rate=cfg.lr, beta1=0.5).minimize(d_loss, var_list=d_vars) g_train_opt = tf.train.AdamOptimizer(learning_rate=cfg.lr, beta1=0.5).minimize(g_loss, var_list=g_vars) top = top_and_bottom[:, :, :, 0:3] # Summary gen__image_sum = tf.summary.image("fake", gen[:, :, :, ::-1], max_outputs=1) real_image_sum = tf.summary.image("real", image_label[:, :, :, ::-1], max_outputs=1) top_sum = tf.summary.image("top", top[:, :, :, ::-1], max_outputs=1) d_loss_real_sum = tf.summary.scalar("d_loss_real", d_loss_real, family="D_loss") d_loss_fake_sum = tf.summary.scalar("d_loss_fake", d_loss_fake, family="D_loss") d_loss_sum = tf.summary.scalar("d_loss", d_loss, family="D_loss") g_loss_l1_sum = tf.summary.scalar("g_mse_loss", g_mse_loss, family="G_loss") g_loss_sum = tf.summary.scalar("g_loss", g_loss, family="G_loss") # final summary operations g_sum = tf.summary.merge([d_loss_fake_sum, g_loss_sum, g_loss_l1_sum, gen__image_sum, real_image_sum, top_sum]) d_sum = tf.summary.merge([d_loss_real_sum, d_loss_sum, gen__image_sum, real_image_sum]) sess = tf.Session() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess, coord) summary_writer = tf.summary.FileWriter(cfg.log_dir, sess.graph) saver = tf.train.Saver() sess.run(tf.global_variables_initializer()) sess.run(train_iterator.initializer) saver.restore(sess, tf.train.latest_checkpoint(cfg.path_save_model)) print("restore successfully !! " * 100) try: for epoch in range(cfg.epoch_size): # Training for itr in range(cfg.dataset_size // cfg.train_batch_size): # noise_label = get_perturbed_batch(image_label) # Update Dicriminator d_loss_val, summary_str, opt_d = sess.run([d_loss, d_sum, d_train_opt]) # Update Generator g_loss_val, summary_str, opt_g = sess.run([g_loss, g_sum, g_train_opt]) if itr % 50 == 0: print("epoch - {} | iter - {} | d-loss - {}".format(epoch, itr, d_loss_val)) summary_writer.add_summary(summary_str, itr) print("epoch - {} | iter - {} | g-loss - {}".format(epoch, itr, g_loss_val)) summary_writer.add_summary(summary_str, itr) # summary_writer.add_summary(clothes_sumarry, itr) saver.save(sess, cfg.path_save_model) print("Successful !!!") except Exception as es: log.debug(es) pass