#!/usr/bin/python3
# Sample GAN implementation, for learning purposes only.
# The networks model two datasets drawn from different distributions; the
# Generator is trained to approximate the target distribution that feeds
# the Discriminator's "real" input.
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# load ui libs
seaborn_available, matplotlib_available = True, True
try:
    import seaborn as sns
except ImportError:
    seaborn_available = False
try:
    import matplotlib as mpl
    mpl.use("TkAgg")  # select the backend before importing pyplot; avoids a possible segfault on macOS
    import matplotlib.pyplot as plt
except ImportError:
    matplotlib_available = False

# cuda
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# target distribution
mean, stddev = 4, 1.25

# hyperparameters
g_in_size = 1
g_hidden_size = 10  # generator complexity
g_out_size = 1
d_in_size = 500
d_hidden_size = 20  # discriminator complexity
d_out_size = 1  # binary output: real or fake
minibatch_size = d_in_size
g_learning_rate = 1e-3
d_learning_rate = 1e-3
sgd_momentum = 0.9
num_epochs = 10000
num_steps = 20
print_interval = 100

# activation functions
g_act = torch.tanh
d_act = torch.sigmoid


# networks
class Generator(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, f):
        super().__init__()
        self.m1 = nn.Linear(in_features, hidden_features)
        self.m2 = nn.Linear(hidden_features, hidden_features)
        self.m3 = nn.Linear(hidden_features, out_features)
        self.f = f

    def forward(self, x):
        x = self.f(self.m1(x))
        x = self.f(self.m2(x))
        return self.m3(x)


class Discriminator(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, f):
        super().__init__()
        self.m1 = nn.Linear(in_features, hidden_features)
        self.m2 = nn.Linear(hidden_features, hidden_features)
        self.m3 = nn.Linear(hidden_features, out_features)
        self.f = f

    def forward(self, x):
        x = self.f(self.m1(x))
        x = self.f(self.m2(x))
        return self.f(self.m3(x))


def get_gaussian_distribution(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))


def get_uniform_distribution():
    return lambda m, n: torch.rand(m, n)


def tensor_to_list(v):
    # detach/cpu so this also works for CUDA tensors and tensors that require grad
    return v.detach().cpu().flatten().tolist()


def stats(d):
    return [round(np.mean(d), 4), round(np.std(d), 4)]


def histplot(data, label="", bins=25):
    if seaborn_available and matplotlib_available:
        sns.histplot(data, label=label, bins=bins)
        plt.legend()
        plt.show()
    else:
        print("seaborn/matplotlib not available")


if __name__ == "__main__":
    D_sampler = get_gaussian_distribution(mean, stddev)  # shape (1, 500)
    G_sampler = get_uniform_distribution()               # shape (500, 1)

    # Uncomment to visualize input distributions
    # if seaborn_available:
    #     histplot(D_sampler(d_in_size).numpy()[0], label="Example Discriminator Input")
    #     histplot(G_sampler(minibatch_size, g_in_size).numpy(), label="Example Generator Input")

    G = Generator(g_in_size, g_hidden_size, g_out_size, g_act)
    D = Discriminator(d_in_size, d_hidden_size, d_out_size, d_act)
    G = G.to(device)
    D = D.to(device)

    # running loss values, reported every print_interval epochs
    D_real_err_rate, D_fake_err_rate, G_err_rate = 0, 0, 0

    # loss & optimizers
    criterion = nn.BCELoss()
    G_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)
    D_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)

    for epoch in range(num_epochs):
        for d_step in range(num_steps):
            D.zero_grad()

            # train Discriminator on real input
            D_real_x = D_sampler(d_in_size).to(device)
            D_real_y = D(D_real_x)
            D_real_err = criterion(D_real_y, torch.ones([1, 1], device=device))  # 1 = real
            D_real_err.backward()

            # train Discriminator on fake input; detach so gradients do not flow into G
            D_fake_input = G_sampler(d_in_size, g_in_size).to(device)
            D_fake_x = G(D_fake_input).detach()
            D_fake_y = D(D_fake_x.t())  # transpose (500, 1) -> (1, 500) to match D's input
            D_fake_err = criterion(D_fake_y, torch.zeros([1, 1], device=device))  # 0 = fake
            D_fake_err.backward()
            D_optimizer.step()

            # record discriminator losses
            D_real_err_rate = tensor_to_list(D_real_err)[0]
            D_fake_err_rate = tensor_to_list(D_fake_err)[0]

        # train Generator with the Discriminator's responses,
        # but do not update the Discriminator
        for g_step in range(num_steps):
            G.zero_grad()
            G_input = G_sampler(minibatch_size, g_in_size).to(device)
            G_fake_x = G(G_input)
            G_fake_y = D(G_fake_x.t())
            G_err = criterion(G_fake_y, torch.ones([1, 1], device=device))  # G wants D to answer "real"
            G_err.backward()
            G_optimizer.step()
            G_err_rate = tensor_to_list(G_err)[0]

        if epoch % print_interval == 0:
            print(f"Epoch: {epoch} "
                  f"Loss (DRE, DFE, GE): {round(D_real_err_rate, 4)} "
                  f"{round(D_fake_err_rate, 4)} {round(G_err_rate, 4)} "
                  f"Dist (Real | Fake): {stats(tensor_to_list(D_real_x))} "
                  f"{stats(tensor_to_list(D_fake_x))}")
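
    # A minimal sketch of post-training inspection (not part of the original
    # script): sample the trained Generator once more and compare its output
    # against a fresh draw from the target Gaussian, reusing the histplot
    # helper defined above.
    if seaborn_available and matplotlib_available:
        with torch.no_grad():
            G_final = G(G_sampler(minibatch_size, g_in_size).to(device))
        histplot(tensor_to_list(G_final), label="Generated Distribution")
        histplot(tensor_to_list(D_sampler(d_in_size)), label="Target Distribution")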