# tensorflow/lucid CPPN (X,Y) --> (R,G,B) Differentiable Image Parameterization in PyTorch

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image

from inception5h import Inception5h


def main():
    # Setup params
    image_size = (224, 224)
    iterations = 512
    lr = 0.005
    use_device = 'cuda:0'

    # Uncomment to use a fixed seed
    #set_seed(876)

    # Target layer and channel
    layer_name = 'mixed4b_3x3_pre_relu_conv'
    channel = 77

    # Load InceptionV1
    cnn = Inception5h()
    cnn.load_state_dict(torch.load('inception5h.pth'))
    cnn.add_layers()  # Add layers not found in model file
    net = cnn.to(use_device).eval()
    for param in net.parameters():
        param.requires_grad = False

    # Uncomment to print all hookable layers
    #print(get_hookable_layers(cnn))

    # Create instance of CPPN
    img_cppn = CPPN_Conv(size=image_size, num_channels=24, num_layers=8,
                         use_device=use_device, normalize=False)
    print('CPPN Params Count', sum(p.numel() for p in img_cppn.parameters() if p.requires_grad))

    loss_modules = register_hook_module(net, layer_name, channel)

    # Create 224x224 image
    output_tensor, img_cppn = dream_cppn(net, img_cppn, iterations, lr, loss_modules, use_device)
    simple_deprocess(output_tensor, name='out.png')

    # Create 720x720 image using multiscale generation: the CPPN maps
    # coordinates to colors, so the same weights render at any resolution
    image_size = (720, 720)
    img_cppn.remake_input(image_size, use_device=use_device)
    output_tensor, img_cppn = dream_cppn(net, img_cppn, iterations, lr, loss_modules, use_device)
    simple_deprocess(output_tensor, name='out_720.png')  # New name, so the 224x224 result is not overwritten


# Simple deprocess: detach from the graph, rescale to [0, 1], and save as PNG
def simple_deprocess(output_tensor, name):
    output_tensor = output_tensor.detach().squeeze(0).cpu() / 255
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor)
    image.save(name)


# Function to maximize CNN activations
def dream_cppn(net, img_cppn, iterations, lr, loss_modules, use_device):
    # Setup optimizer to optimize the CPPN's weights instead of raw pixels
    optimizer = optim.Adam(img_cppn.parameters(), lr=lr)

    # Training loop
    for i in range(iterations):
        optimizer.zero_grad()
        img = img_cppn() * 255  # Create RGB image with CPPN
        net(img)  # Forward pass; the hooks record the loss values
        loss = 0
        for mod in loss_modules:  # Collect loss values
            loss += mod.loss
        loss.backward()

        # Uncomment to save intermediate iterations
        #if i % 25 == 0:
        #    simple_deprocess(img.detach(), 'out_' + str(i) + '.png')

        print('Iteration', i + 1, 'Loss', loss.item())
        optimizer.step()

    img = img_cppn() * 255
    return img, img_cppn
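
# A minimal sketch, not part of the original script: reuses the pieces above
# to render several channels of one layer in sequence. The function name
# dream_channel_sweep is hypothetical. It keeps the handle returned by
# register_forward_hook so each hook can be removed after its run; otherwise
# hooks from earlier channels would stay attached and their losses would keep
# accumulating into later runs.
def dream_channel_sweep(net, layer_name, channels, iterations=512, lr=0.005,
                        use_device='cuda:0', image_size=(224, 224)):
    for c in channels:
        img_cppn = CPPN_Conv(size=image_size, num_channels=24, num_layers=8,
                             use_device=use_device, normalize=False)
        loss_module = SimpleDreamLossHook(c)
        handle = getattr(net, layer_name).register_forward_hook(loss_module)
        output_tensor, img_cppn = dream_cppn(net, img_cppn, iterations, lr,
                                             [loss_module], use_device)
        simple_deprocess(output_tensor, name='out_channel_' + str(c) + '.png')
        handle.remove()  # Detach the hook before moving to the next channel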

# Activation function for CPPN: concatenates atan(x) with atan(x)^2 along the
# channel dimension; the constants roughly normalize each part to unit RMS
# for standard-normal inputs, matching lucid's composite activation
class CompositeActivation(nn.Module):

    def forward(self, input):
        input = torch.atan(input)
        return torch.cat([input / 0.67, (input * input) / 0.6], 1)


# Compositional pattern-producing network (CPPN) built from 1x1 Conv2d layers
class CPPN_Conv(nn.Module):

    def __init__(self, size=(224, 224), num_channels=24, num_layers=8,
                 activ_func=CompositeActivation(), use_device='cpu', normalize=False):
        super(CPPN_Conv, self).__init__()
        self.input_size = size
        self.n_channels = num_channels
        self.net = self.create_net(num_channels, num_layers, activ_func, use_device, normalize)
        self.cppn_input = self.create_input(use_device)

    # Create CPPN (X,Y) --> (R,G,B)
    def create_net(self, num_channels, num_layers, activ_func, use_device,
                   instance_norm, affine=True, bias=True):
        net = nn.Sequential()
        net.add_module(str(len(net)), nn.Conv2d(in_channels=2, out_channels=num_channels,
                                                kernel_size=1, bias=bias))
        if instance_norm:
            net.add_module(str(len(net)), nn.InstanceNorm2d(num_channels, affine=affine))
        net.add_module(str(len(net)), activ_func)
        for l in range(num_layers - 1):
            # CompositeActivation doubles the channel count, hence num_channels * 2 inputs
            net.add_module(str(len(net)), nn.Conv2d(in_channels=num_channels * 2,
                                                    out_channels=num_channels, kernel_size=1, bias=bias))
            if instance_norm:
                net.add_module(str(len(net)), nn.InstanceNorm2d(num_channels, affine=affine))
            net.add_module(str(len(net)), activ_func)
        net.add_module(str(len(net)), nn.Conv2d(in_channels=num_channels * 2, out_channels=3,
                                                kernel_size=1, bias=bias))
        net.add_module(str(len(net)), nn.Sigmoid())
        net.apply(self.cppn_normal)
        return net.to(use_device)

    # Create X,Y input for CPPN with shape (1, 2, H, W)
    def create_input(self, use_device):
        if not isinstance(self.input_size, (tuple, list)):
            self.input_size = (self.input_size, self.input_size)
        h, w = self.input_size
        w_range = torch.arange(0, w).to(use_device)
        h_range = torch.arange(0, h).to(use_device)
        # Normalize each coordinate by its own dimension so values lie in [-0.5, 0.5)
        w_exp = w_range.unsqueeze(1).expand((w, h)).true_divide(w) - 0.5
        h_exp = h_range.unsqueeze(0).expand((w, h)).true_divide(h) - 0.5
        return torch.stack((w_exp, h_exp), -1).permute(2, 1, 0).unsqueeze(0)

    # Replace input with a different sized input (for multiscale generation)
    def remake_input(self, image_size, use_device):
        self.input_size = image_size
        self.cppn_input = self.create_input(use_device)

    # Initialize Conv2d weights with zero-mean normals of variance
    # 1 / num_channels, and zero the biases
    def cppn_normal(self, l):
        if type(l) == nn.Conv2d:
            l.weight.data.normal_(0, (1.0 / self.n_channels) ** (1 / 2))
            if l.bias is not None:
                nn.init.zeros_(l.bias)

    def forward(self):
        return self.net(self.cppn_input)


# Create loss module and attach it to the target layer as a forward hook
def register_hook_module(net, layer_name, channel=-1):
    loss_module = SimpleDreamLossHook(channel)
    getattr(net, layer_name).register_forward_hook(loss_module)
    return [loss_module]


# A simple forward hook that collects the DeepDream loss
class SimpleDreamLossHook(nn.Module):

    def __init__(self, channel=-1):
        super(SimpleDreamLossHook, self).__init__()
        self.get_loss = self.mean_loss
        self.channel = channel

    def mean_loss(self, input):
        return input.mean()

    def forward(self, module, input, output):
        # Negate the mean activation so that minimizing the loss maximizes it;
        # a channel of -1 targets the whole layer
        if self.channel > -1:
            self.loss = -self.get_loss(output[:, self.channel, :, :])
        else:
            self.loss = -self.get_loss(output)


# Set global seeds for reproducible output
def set_seed(seed):
    import random
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    random.seed(seed)


# Get the names of layers that can be hooked
def get_hookable_layers(cnn):
    return [l[0] for l in list(cnn.named_children())]


if __name__ == "__main__":
    main()
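
# Usage notes (assumptions, not stated in the original script): running this
# file requires PyTorch, torchvision, and PIL, plus the Inception5h model
# definition and its 'inception5h.pth' weights file alongside this script.
# For a reproducible run, seed before calling main(), e.g.:
#
#   set_seed(876)
#   main()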