Created February 5, 2019 20:33
An excerpt from Leon Gatys' Jupyter Notebook on style transfer
```python
import torch
import torch.nn as nn

# GramMatrix, GramMSELoss, the VGG model `vgg`, and the preprocessed
# `style_image` / `content_image` tensors are defined elsewhere in the notebook.

# Define layers, loss functions and weights, and compute optimization targets.
style_layers = ['stack1_layer1', 'stack2_layer1', 'stack3_layer1', 'stack4_layer1', 'stack5_layer1']
content_layers = ['stack4_layer2']
loss_layers = style_layers + content_layers
# One Gram-matrix loss per style layer, plain MSE for each content layer.
loss_fns = [GramMSELoss()] * len(style_layers) + [nn.MSELoss()] * len(content_layers)
if torch.cuda.is_available():
    loss_fns = [loss_fn.cuda() for loss_fn in loss_fns]

# These weight settings work well: each style layer is scaled by 1/n**2,
# where n is that layer's number of feature channels.
style_weights = [1e3 / n**2 for n in [64, 128, 256, 512, 512]]
content_weights = [1e0]
weights = style_weights + content_weights

# Compute optimization targets; detach() removes them from the autograd graph
# so they act as fixed constants during optimization.
style_targets = [GramMatrix()(A).detach() for A in vgg(style_image, style_layers)]
content_targets = [A.detach() for A in vgg(content_image, content_layers)]
targets = style_targets + content_targets
```
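GramMatrix and GramMSELoss are not defined in this excerpt. The sketch below shows what a Gram-matrix style loss computes. It assumes a single image whose feature map is given as a list of channels, each flattened to its h*w activations, and it assumes the Gram matrix is normalized by h*w; the notebook's own classes may differ in detail. The function names are hypothetical, and plain lists stand in for tensors so it runs without PyTorch.

```python
# Minimal pure-Python sketch of a Gram matrix and a Gram-based MSE loss.
# Assumption: features = list of channels, each a flat list of h*w activations.
# Assumption: the Gram matrix is divided by the number of spatial positions (h*w).

def gram_matrix(features):
    """Return the c x c Gram matrix G[i][j] = <F_i, F_j> / (h*w)."""
    num_positions = len(features[0])
    return [
        [sum(a * b for a, b in zip(fi, fj)) / num_positions for fj in features]
        for fi in features
    ]


def mse(xs, ys):
    """Mean squared error between two equally shaped 2-D nested lists."""
    flat_x = [v for row in xs for v in row]
    flat_y = [v for row in ys for v in row]
    return sum((a - b) ** 2 for a, b in zip(flat_x, flat_y)) / len(flat_x)


def gram_mse_loss(features, target_gram):
    """Compare the Gram matrix of `features` with a precomputed target Gram matrix."""
    return mse(gram_matrix(features), target_gram)


# Usage: two channels over four spatial positions (h*w = 4).
style_feats = [[1.0, 0.0, 1.0, 0.0],
               [0.0, 1.0, 0.0, 1.0]]
target = gram_matrix(style_feats)          # computed once, like style_targets
print(target)                              # [[0.5, 0.0], [0.0, 0.5]]
print(gram_mse_loss(style_feats, target))  # 0.0
```

Because the Gram matrix sums over spatial positions, it records which channels fire together but discards where they fire, which is why it serves as the style target while the content layer is compared directly with plain MSE.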
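The style weights in the excerpt are simple arithmetic on the list [64, 128, 256, 512, 512], which matches the channel widths of VGG's five convolutional stacks. The worked example below evaluates them in plain Python; it reuses the excerpt's layer names only as labels.

```python
# Worked example of the excerpt's weight settings (plain Python, no PyTorch).
style_layers = ['stack1_layer1', 'stack2_layer1', 'stack3_layer1', 'stack4_layer1', 'stack5_layer1']
# Feature channels per style layer, in the same order as the excerpt's list.
style_channels = [64, 128, 256, 512, 512]

# Exactly 125/512, 125/2048, 125/8192, 125/32768, 125/32768.
style_weights = [1e3 / n**2 for n in style_channels]
content_weights = [1e0]
weights = style_weights + content_weights  # 5 style weights, then 1 content weight

for name, n, w in zip(style_layers, style_channels, style_weights):
    print(f"{name}: n={n:3d}, weight={w:.3e}")

# The shallowest style layer gets (512 / 64)**2 = 64 times the weight of the deepest.
print(style_weights[0] / style_weights[-1])  # 64.0
```

Scaling by 1/n² shrinks the contribution of wide layers, whose Gram matrices have n² entries, so that no single style layer dominates the sum.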
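The excerpt builds loss_fns, weights and targets as parallel lists in the same order as loss_layers, but the optimization loop that consumes them is not shown. The sketch below assumes that each layer's output is paired index-by-index with its loss function, weight and target, and that the weighted per-layer losses are summed into one objective. The function names are hypothetical, plain floats stand in for tensors, and the numbers in the usage line are made-up placeholders.

```python
from typing import Callable, Sequence

# Assumption: outputs[i] is the network activation at loss_layers[i], aligned
# index-by-index with loss_fns, weights and targets.


def weighted_total_loss(
    outputs: Sequence[float],
    targets: Sequence[float],
    loss_fns: Sequence[Callable[[float, float], float]],
    weights: Sequence[float],
) -> float:
    """Return sum_i weights[i] * loss_fns[i](outputs[i], targets[i])."""
    if not (len(outputs) == len(targets) == len(loss_fns) == len(weights)):
        raise ValueError("outputs, targets, loss_fns and weights must align one-to-one")
    return sum(w * fn(out, tgt) for out, tgt, fn, w in zip(outputs, targets, loss_fns, weights))


def squared_error(x: float, y: float) -> float:
    # Hypothetical stand-in for the per-layer loss modules.
    return (x - y) ** 2


# Usage with made-up numbers: one "style" term and one "content" term.
print(weighted_total_loss([1.0, 2.0], [0.0, 0.0], [squared_error, squared_error], [0.5, 2.0]))  # 8.5
```

Keeping the four lists parallel lets a single loop treat style and content terms uniformly; the length check catches a layer added to loss_layers without a matching weight or target.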