|
|
@@ -0,0 +1,644 @@ |
|
|
-- modified version of @jcjohnson's neural-style |
|
|
-- by @htoyryla 14 May 2017 |
|
|
-- allows exploring th effect of giving emphasis to specific channel(s) in a style layer |
|
|
-- use -style_layers to select layer as usual, using a single layer is recommended |
|
|
-- -style_channels to specify one or more channels to be given emphasis, other channels will be attenuated |
|
|
-- if a channel does not exist in a layer, it is ignored and given a warning |
|
|
-- note: if multiple style layers are used, style_channels setting affects both |
|
|
|
|
|
require 'torch' |
|
|
require 'nn' |
|
|
require 'image' |
|
|
require 'optim' |
|
|
|
|
|
require 'loadcaffe' |
|
|
|
|
|
|
|
|
local cmd = torch.CmdLine() |
|
|
|
|
|
-- Basic options |
|
|
cmd:option('-style_image', 'examples/inputs/seated-nude.jpg', |
|
|
'Style target image') |
|
|
cmd:option('-style_blend_weights', 'nil') |
|
|
cmd:option('-content_image', 'examples/inputs/tubingen.jpg', |
|
|
'Content target image') |
|
|
cmd:option('-image_size', 512, 'Maximum height / width of generated image') |
|
|
cmd:option('-gpu', '0', 'Zero-indexed ID of the GPU to use; for CPU mode set -gpu = -1') |
|
|
cmd:option('-multigpu_strategy', '', 'Index of layers to split the network across GPUs') |
|
|
|
|
|
-- Optimization options |
|
|
cmd:option('-content_weight', 5e0) |
|
|
cmd:option('-style_weight', 1e2) |
|
|
cmd:option('-tv_weight', 1e-3) |
|
|
cmd:option('-num_iterations', 1000) |
|
|
cmd:option('-normalize_gradients', false) |
|
|
cmd:option('-init', 'random', 'random|image') |
|
|
cmd:option('-init_image', '') |
|
|
cmd:option('-optimizer', 'lbfgs', 'lbfgs|adam') |
|
|
cmd:option('-learning_rate', 1e1) |
|
|
cmd:option('-lbfgs_num_correction', 0) |
|
|
|
|
|
-- Output options |
|
|
cmd:option('-print_iter', 50) |
|
|
cmd:option('-save_iter', 100) |
|
|
cmd:option('-output_image', 'out.png') |
|
|
|
|
|
-- Other options |
|
|
cmd:option('-style_scale', 1.0) |
|
|
cmd:option('-original_colors', 0) |
|
|
cmd:option('-pooling', 'max', 'max|avg') |
|
|
cmd:option('-proto_file', 'models/VGG_ILSVRC_19_layers_deploy.prototxt') |
|
|
cmd:option('-model_file', 'models/VGG_ILSVRC_19_layers.caffemodel') |
|
|
cmd:option('-backend', 'nn', 'nn|cudnn|clnn') |
|
|
cmd:option('-cudnn_autotune', false) |
|
|
cmd:option('-seed', -1) |
|
|
|
|
|
cmd:option('-content_layers', 'relu4_2', 'layers for content') |
|
|
cmd:option('-style_layers', 'relu1_1,relu2_1,relu3_1,relu4_1,relu5_1', 'layers for style') |
|
|
cmd:option('-style_channels', 'all', 'channels for style') |
|
|
|
|
|
local style_channels = nil |
|
|
|
|
|
local function main(params) |
|
|
local dtype, multigpu = setup_gpu(params) |
|
|
|
|
|
local loadcaffe_backend = params.backend |
|
|
if params.backend == 'clnn' then loadcaffe_backend = 'nn' end |
|
|
local cnn = loadcaffe.load(params.proto_file, params.model_file, loadcaffe_backend):type(dtype) |
|
|
|
|
|
local content_image = image.load(params.content_image, 3) |
|
|
content_image = image.scale(content_image, params.image_size, 'bilinear') |
|
|
local content_image_caffe = preprocess(content_image):float() |
|
|
|
|
|
local style_size = math.ceil(params.style_scale * params.image_size) |
|
|
local style_image_list = params.style_image:split(',') |
|
|
local style_images_caffe = {} |
|
|
for _, img_path in ipairs(style_image_list) do |
|
|
local img = image.load(img_path, 3) |
|
|
img = image.scale(img, style_size, 'bilinear') |
|
|
local img_caffe = preprocess(img):float() |
|
|
table.insert(style_images_caffe, img_caffe) |
|
|
end |
|
|
|
|
|
local init_image = nil |
|
|
if params.init_image ~= '' then |
|
|
init_image = image.load(params.init_image, 3) |
|
|
local H, W = content_image:size(2), content_image:size(3) |
|
|
init_image = image.scale(init_image, W, H, 'bilinear') |
|
|
init_image = preprocess(init_image):float() |
|
|
end |
|
|
|
|
|
if params.style_channels ~= 'all' then |
|
|
style_channels = params.style_channels:split(',') |
|
|
for i=1,#style_channels do |
|
|
style_channels[i] = tonumber(style_channels[i]) |
|
|
end |
|
|
end |
|
|
|
|
|
-- Handle style blending weights for multiple style inputs |
|
|
local style_blend_weights = nil |
|
|
if params.style_blend_weights == 'nil' then |
|
|
-- Style blending not specified, so use equal weighting |
|
|
style_blend_weights = {} |
|
|
for i = 1, #style_image_list do |
|
|
table.insert(style_blend_weights, 1.0) |
|
|
end |
|
|
else |
|
|
style_blend_weights = params.style_blend_weights:split(',') |
|
|
assert(#style_blend_weights == #style_image_list, |
|
|
'-style_blend_weights and -style_images must have the same number of elements') |
|
|
end |
|
|
-- Normalize the style blending weights so they sum to 1 |
|
|
local style_blend_sum = 0 |
|
|
for i = 1, #style_blend_weights do |
|
|
style_blend_weights[i] = tonumber(style_blend_weights[i]) |
|
|
style_blend_sum = style_blend_sum + style_blend_weights[i] |
|
|
end |
|
|
for i = 1, #style_blend_weights do |
|
|
style_blend_weights[i] = style_blend_weights[i] / style_blend_sum |
|
|
end |
|
|
|
|
|
local content_layers = params.content_layers:split(",") |
|
|
local style_layers = params.style_layers:split(",") |
|
|
|
|
|
-- Set up the network, inserting style and content loss modules |
|
|
local content_losses, style_losses = {}, {} |
|
|
local next_content_idx, next_style_idx = 1, 1 |
|
|
local net = nn.Sequential() |
|
|
if params.tv_weight > 0 then |
|
|
local tv_mod = nn.TVLoss(params.tv_weight):type(dtype) |
|
|
net:add(tv_mod) |
|
|
end |
|
|
for i = 1, #cnn do |
|
|
if next_content_idx <= #content_layers or next_style_idx <= #style_layers then |
|
|
local layer = cnn:get(i) |
|
|
local name = layer.name |
|
|
local layer_type = torch.type(layer) |
|
|
local is_pooling = (layer_type == 'cudnn.SpatialMaxPooling' or layer_type == 'nn.SpatialMaxPooling') |
|
|
if is_pooling and params.pooling == 'avg' then |
|
|
assert(layer.padW == 0 and layer.padH == 0) |
|
|
local kW, kH = layer.kW, layer.kH |
|
|
local dW, dH = layer.dW, layer.dH |
|
|
local avg_pool_layer = nn.SpatialAveragePooling(kW, kH, dW, dH):type(dtype) |
|
|
local msg = 'Replacing max pooling at layer %d with average pooling' |
|
|
print(string.format(msg, i)) |
|
|
net:add(avg_pool_layer) |
|
|
else |
|
|
net:add(layer) |
|
|
end |
|
|
if name == content_layers[next_content_idx] then |
|
|
print("Setting up content layer", i, ":", layer.name) |
|
|
local norm = params.normalize_gradients |
|
|
local loss_module = nn.ContentLoss(params.content_weight, norm):type(dtype) |
|
|
net:add(loss_module) |
|
|
table.insert(content_losses, loss_module) |
|
|
next_content_idx = next_content_idx + 1 |
|
|
end |
|
|
if name == style_layers[next_style_idx] then |
|
|
print("Setting up style layer ", i, ":", layer.name) |
|
|
local norm = params.normalize_gradients |
|
|
local loss_module = nn.StyleLoss(params.style_weight, norm):type(dtype) |
|
|
net:add(loss_module) |
|
|
table.insert(style_losses, loss_module) |
|
|
next_style_idx = next_style_idx + 1 |
|
|
end |
|
|
end |
|
|
end |
|
|
if multigpu then |
|
|
net = setup_multi_gpu(net, params) |
|
|
end |
|
|
net:type(dtype) |
|
|
|
|
|
-- Capture content targets |
|
|
for i = 1, #content_losses do |
|
|
content_losses[i].mode = 'capture' |
|
|
end |
|
|
print 'Capturing content targets' |
|
|
print(net) |
|
|
content_image_caffe = content_image_caffe:type(dtype) |
|
|
net:forward(content_image_caffe:type(dtype)) |
|
|
|
|
|
-- Capture style targets |
|
|
for i = 1, #content_losses do |
|
|
content_losses[i].mode = 'none' |
|
|
end |
|
|
for i = 1, #style_images_caffe do |
|
|
print(string.format('Capturing style target %d', i)) |
|
|
for j = 1, #style_losses do |
|
|
style_losses[j].mode = 'capture' |
|
|
style_losses[j].blend_weight = style_blend_weights[i] |
|
|
end |
|
|
net:forward(style_images_caffe[i]:type(dtype)) |
|
|
end |
|
|
|
|
|
-- Set all loss modules to loss mode |
|
|
for i = 1, #content_losses do |
|
|
content_losses[i].mode = 'loss' |
|
|
end |
|
|
for i = 1, #style_losses do |
|
|
style_losses[i].mode = 'loss' |
|
|
end |
|
|
|
|
|
-- We don't need the base CNN anymore, so clean it up to save memory. |
|
|
cnn = nil |
|
|
for i=1, #net.modules do |
|
|
local module = net.modules[i] |
|
|
if torch.type(module) == 'nn.SpatialConvolutionMM' then |
|
|
-- remove these, not used, but uses gpu memory |
|
|
module.gradWeight = nil |
|
|
module.gradBias = nil |
|
|
end |
|
|
end |
|
|
collectgarbage() |
|
|
|
|
|
-- Initialize the image |
|
|
if params.seed >= 0 then |
|
|
torch.manualSeed(params.seed) |
|
|
end |
|
|
local img = nil |
|
|
if params.init == 'random' then |
|
|
img = torch.randn(content_image:size()):float():mul(0.001) |
|
|
elseif params.init == 'image' then |
|
|
if init_image then |
|
|
img = init_image:clone() |
|
|
else |
|
|
img = content_image_caffe:clone() |
|
|
end |
|
|
else |
|
|
error('Invalid init type') |
|
|
end |
|
|
img = img:type(dtype) |
|
|
|
|
|
-- Run it through the network once to get the proper size for the gradient |
|
|
-- All the gradients will come from the extra loss modules, so we just pass |
|
|
-- zeros into the top of the net on the backward pass. |
|
|
local y = net:forward(img) |
|
|
local dy = img.new(#y):zero() |
|
|
|
|
|
-- Declaring this here lets us access it in maybe_print |
|
|
local optim_state = nil |
|
|
if params.optimizer == 'lbfgs' then |
|
|
optim_state = { |
|
|
maxIter = params.num_iterations, |
|
|
verbose=true, |
|
|
tolX=-1, |
|
|
tolFun=-1, |
|
|
} |
|
|
if params.lbfgs_num_correction > 0 then |
|
|
optim_state.nCorrection = params.lbfgs_num_correction |
|
|
end |
|
|
elseif params.optimizer == 'adam' then |
|
|
optim_state = { |
|
|
learningRate = params.learning_rate, |
|
|
} |
|
|
else |
|
|
error(string.format('Unrecognized optimizer "%s"', params.optimizer)) |
|
|
end |
|
|
|
|
|
local function maybe_print(t, loss) |
|
|
local verbose = (params.print_iter > 0 and t % params.print_iter == 0) |
|
|
if verbose then |
|
|
print(string.format('Iteration %d / %d', t, params.num_iterations)) |
|
|
for i, loss_module in ipairs(content_losses) do |
|
|
print(string.format(' Content %d loss: %f', i, loss_module.loss)) |
|
|
end |
|
|
for i, loss_module in ipairs(style_losses) do |
|
|
print(string.format(' Style %d loss: %f', i, loss_module.loss)) |
|
|
end |
|
|
print(string.format(' Total loss: %f', loss)) |
|
|
end |
|
|
end |
|
|
|
|
|
local function maybe_save(t) |
|
|
local should_save = params.save_iter > 0 and t % params.save_iter == 0 |
|
|
should_save = should_save or t == params.num_iterations |
|
|
if should_save then |
|
|
local disp = deprocess(img:double()) |
|
|
disp = image.minmax{tensor=disp, min=0, max=1} |
|
|
local filename = build_filename(params.output_image, t) |
|
|
if t == params.num_iterations then |
|
|
filename = params.output_image |
|
|
end |
|
|
|
|
|
-- Maybe perform postprocessing for color-independent style transfer |
|
|
if params.original_colors == 1 then |
|
|
disp = original_colors(content_image, disp) |
|
|
end |
|
|
|
|
|
image.save(filename, disp) |
|
|
end |
|
|
end |
|
|
|
|
|
-- Function to evaluate loss and gradient. We run the net forward and |
|
|
-- backward to get the gradient, and sum up losses from the loss modules. |
|
|
-- optim.lbfgs internally handles iteration and calls this function many |
|
|
-- times, so we manually count the number of iterations to handle printing |
|
|
-- and saving intermediate results. |
|
|
local num_calls = 0 |
|
|
local function feval(x) |
|
|
num_calls = num_calls + 1 |
|
|
net:forward(x) |
|
|
local grad = net:updateGradInput(x, dy) |
|
|
local loss = 0 |
|
|
for _, mod in ipairs(content_losses) do |
|
|
loss = loss + mod.loss |
|
|
end |
|
|
for _, mod in ipairs(style_losses) do |
|
|
loss = loss + mod.loss |
|
|
end |
|
|
maybe_print(num_calls, loss) |
|
|
maybe_save(num_calls) |
|
|
|
|
|
collectgarbage() |
|
|
-- optim.lbfgs expects a vector for gradients |
|
|
return loss, grad:view(grad:nElement()) |
|
|
end |
|
|
|
|
|
-- Run optimization. |
|
|
if params.optimizer == 'lbfgs' then |
|
|
print('Running optimization with L-BFGS') |
|
|
local x, losses = optim.lbfgs(feval, img, optim_state) |
|
|
elseif params.optimizer == 'adam' then |
|
|
print('Running optimization with ADAM') |
|
|
for t = 1, params.num_iterations do |
|
|
local x, losses = optim.adam(feval, img, optim_state) |
|
|
end |
|
|
end |
|
|
end |
|
|
|
|
|
|
|
|
function setup_gpu(params) |
|
|
local multigpu = false |
|
|
if params.gpu:find(',') then |
|
|
multigpu = true |
|
|
params.gpu = params.gpu:split(',') |
|
|
for i = 1, #params.gpu do |
|
|
params.gpu[i] = tonumber(params.gpu[i]) + 1 |
|
|
end |
|
|
else |
|
|
params.gpu = tonumber(params.gpu) + 1 |
|
|
end |
|
|
local dtype = 'torch.FloatTensor' |
|
|
if multigpu or params.gpu > 0 then |
|
|
if params.backend ~= 'clnn' then |
|
|
require 'cutorch' |
|
|
require 'cunn' |
|
|
if multigpu then |
|
|
cutorch.setDevice(params.gpu[1]) |
|
|
else |
|
|
cutorch.setDevice(params.gpu) |
|
|
end |
|
|
dtype = 'torch.CudaTensor' |
|
|
else |
|
|
require 'clnn' |
|
|
require 'cltorch' |
|
|
if multigpu then |
|
|
cltorch.setDevice(params.gpu[1]) |
|
|
else |
|
|
cltorch.setDevice(params.gpu) |
|
|
end |
|
|
dtype = torch.Tensor():cl():type() |
|
|
end |
|
|
else |
|
|
params.backend = 'nn' |
|
|
end |
|
|
|
|
|
if params.backend == 'cudnn' then |
|
|
require 'cudnn' |
|
|
if params.cudnn_autotune then |
|
|
cudnn.benchmark = true |
|
|
end |
|
|
cudnn.SpatialConvolution.accGradParameters = nn.SpatialConvolutionMM.accGradParameters -- ie: nop |
|
|
end |
|
|
return dtype, multigpu |
|
|
end |
|
|
|
|
|
|
|
|
function setup_multi_gpu(net, params) |
|
|
local DEFAULT_STRATEGIES = { |
|
|
[2] = {3}, |
|
|
} |
|
|
local gpu_splits = nil |
|
|
if params.multigpu_strategy == '' then |
|
|
-- Use a default strategy |
|
|
gpu_splits = DEFAULT_STRATEGIES[#params.gpu] |
|
|
-- Offset the default strategy by one if we are using TV |
|
|
if params.tv_weight > 0 then |
|
|
for i = 1, #gpu_splits do gpu_splits[i] = gpu_splits[i] + 1 end |
|
|
end |
|
|
else |
|
|
-- Use the user-specified multigpu strategy |
|
|
gpu_splits = params.multigpu_strategy:split(',') |
|
|
for i = 1, #gpu_splits do |
|
|
gpu_splits[i] = tonumber(gpu_splits[i]) |
|
|
end |
|
|
end |
|
|
assert(gpu_splits ~= nil, 'Must specify -multigpu_strategy') |
|
|
local gpus = params.gpu |
|
|
|
|
|
local cur_chunk = nn.Sequential() |
|
|
local chunks = {} |
|
|
for i = 1, #net do |
|
|
cur_chunk:add(net:get(i)) |
|
|
if i == gpu_splits[1] then |
|
|
table.remove(gpu_splits, 1) |
|
|
table.insert(chunks, cur_chunk) |
|
|
cur_chunk = nn.Sequential() |
|
|
end |
|
|
end |
|
|
table.insert(chunks, cur_chunk) |
|
|
assert(#chunks == #gpus) |
|
|
|
|
|
local new_net = nn.Sequential() |
|
|
for i = 1, #chunks do |
|
|
local out_device = nil |
|
|
if i == #chunks then |
|
|
out_device = gpus[1] |
|
|
end |
|
|
new_net:add(nn.GPU(chunks[i], gpus[i], out_device)) |
|
|
end |
|
|
|
|
|
return new_net |
|
|
end |
|
|
|
|
|
|
|
|
function build_filename(output_image, iteration) |
|
|
local ext = paths.extname(output_image) |
|
|
local basename = paths.basename(output_image, ext) |
|
|
local directory = paths.dirname(output_image) |
|
|
return string.format('%s/%s_%d.%s',directory, basename, iteration, ext) |
|
|
end |
|
|
|
|
|
|
|
|
-- Preprocess an image before passing it to a Caffe model. |
|
|
-- We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR, |
|
|
-- and subtract the mean pixel. |
|
|
function preprocess(img) |
|
|
local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68}) |
|
|
local perm = torch.LongTensor{3, 2, 1} |
|
|
img = img:index(1, perm):mul(256.0) |
|
|
mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img) |
|
|
img:add(-1, mean_pixel) |
|
|
return img |
|
|
end |
|
|
|
|
|
|
|
|
-- Undo the above preprocessing. |
|
|
function deprocess(img) |
|
|
local mean_pixel = torch.DoubleTensor({103.939, 116.779, 123.68}) |
|
|
mean_pixel = mean_pixel:view(3, 1, 1):expandAs(img) |
|
|
img = img + mean_pixel |
|
|
local perm = torch.LongTensor{3, 2, 1} |
|
|
img = img:index(1, perm):div(256.0) |
|
|
return img |
|
|
end |
|
|
|
|
|
|
|
|
-- Combine the Y channel of the generated image and the UV channels of the |
|
|
-- content image to perform color-independent style transfer. |
|
|
function original_colors(content, generated) |
|
|
local generated_y = image.rgb2yuv(generated)[{{1, 1}}] |
|
|
local content_uv = image.rgb2yuv(content)[{{2, 3}}] |
|
|
return image.yuv2rgb(torch.cat(generated_y, content_uv, 1)) |
|
|
end |
|
|
|
|
|
|
|
|
-- Define an nn Module to compute content loss in-place |
|
|
local ContentLoss, parent = torch.class('nn.ContentLoss', 'nn.Module') |
|
|
|
|
|
function ContentLoss:__init(strength, normalize) |
|
|
parent.__init(self) |
|
|
self.strength = strength |
|
|
self.target = torch.Tensor() |
|
|
self.normalize = normalize or false |
|
|
self.loss = 0 |
|
|
self.crit = nn.MSECriterion() |
|
|
self.mode = 'none' |
|
|
end |
|
|
|
|
|
function ContentLoss:updateOutput(input) |
|
|
if self.mode == 'loss' then |
|
|
self.loss = self.crit:forward(input, self.target) * self.strength |
|
|
elseif self.mode == 'capture' then |
|
|
self.target:resizeAs(input):copy(input) |
|
|
end |
|
|
self.output = input |
|
|
return self.output |
|
|
end |
|
|
|
|
|
function ContentLoss:updateGradInput(input, gradOutput) |
|
|
if self.mode == 'loss' then |
|
|
if input:nElement() == self.target:nElement() then |
|
|
self.gradInput = self.crit:backward(input, self.target) |
|
|
end |
|
|
if self.normalize then |
|
|
self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8) |
|
|
end |
|
|
self.gradInput:mul(self.strength) |
|
|
self.gradInput:add(gradOutput) |
|
|
else |
|
|
self.gradInput:resizeAs(gradOutput):copy(gradOutput) |
|
|
end |
|
|
return self.gradInput |
|
|
end |
|
|
|
|
|
|
|
|
local Gram, parent = torch.class('nn.GramMatrix', 'nn.Module') |
|
|
|
|
|
function Gram:__init() |
|
|
parent.__init(self) |
|
|
end |
|
|
|
|
|
function Gram:updateOutput(input) |
|
|
assert(input:dim() == 3) |
|
|
local C, H, W = input:size(1), input:size(2), input:size(3) |
|
|
local x_flat = input:view(C, H * W) |
|
|
self.output:resize(C, C) |
|
|
self.output:mm(x_flat, x_flat:t()) |
|
|
return self.output |
|
|
end |
|
|
|
|
|
function Gram:updateGradInput(input, gradOutput) |
|
|
assert(input:dim() == 3 and input:size(1)) |
|
|
local C, H, W = input:size(1), input:size(2), input:size(3) |
|
|
local x_flat = input:view(C, H * W) |
|
|
self.gradInput:resize(C, H * W):mm(gradOutput, x_flat) |
|
|
self.gradInput:addmm(gradOutput:t(), x_flat) |
|
|
self.gradInput = self.gradInput:view(C, H, W) |
|
|
return self.gradInput |
|
|
end |
|
|
|
|
|
function inputMask(C, H, W) |
|
|
local t = torch.Tensor(C,H,W):fill(1) |
|
|
local m = torch.Tensor(C,H,W):fill(0.2) |
|
|
if style_channels ~= nil then |
|
|
for i=1,#style_channels do |
|
|
local sch = style_channels[i] |
|
|
--print(i, sch) |
|
|
if sch > C then |
|
|
print("skipping non-existent channel ",sch) |
|
|
end |
|
|
m[sch] = m[sch]*5 |
|
|
end |
|
|
end |
|
|
return t:cmul(m):cuda() |
|
|
end |
|
|
|
|
|
-- Define an nn Module to compute style loss in-place |
|
|
local StyleLoss, parent = torch.class('nn.StyleLoss', 'nn.Module') |
|
|
|
|
|
function StyleLoss:__init(strength, normalize) |
|
|
parent.__init(self) |
|
|
self.normalize = normalize or false |
|
|
self.strength = strength |
|
|
self.target = torch.Tensor() |
|
|
self.mode = 'none' |
|
|
self.loss = 0 |
|
|
self.mask = nil |
|
|
|
|
|
self.gram = nn.GramMatrix() |
|
|
self.blend_weight = nil |
|
|
self.G = nil |
|
|
self.crit = nn.MSECriterion() |
|
|
end |
|
|
|
|
|
function StyleLoss:updateOutput(input) |
|
|
if (self.mask == nil) or (self.mask:size() ~= input:size()) then |
|
|
self.mask = inputMask(input:size(1), input:size(2), input:size(3)) |
|
|
end |
|
|
--print('updateOutput', input:size(), self.mask:size()) |
|
|
local inp = input:clone():cmul(self.mask) |
|
|
self.G = self.gram:forward(inp) |
|
|
self.G:div(inp:nElement()) |
|
|
if self.mode == 'capture' then |
|
|
if self.blend_weight == nil then |
|
|
self.target:resizeAs(self.G):copy(self.G) |
|
|
elseif self.target:nElement() == 0 then |
|
|
self.target:resizeAs(self.G):copy(self.G):mul(self.blend_weight) |
|
|
else |
|
|
self.target:add(self.blend_weight, self.G) |
|
|
end |
|
|
elseif self.mode == 'loss' then |
|
|
self.loss = self.strength * self.crit:forward(self.G, self.target) |
|
|
end |
|
|
self.output = input |
|
|
return self.output |
|
|
end |
|
|
|
|
|
function StyleLoss:updateGradInput(input, gradOutput) |
|
|
if self.mask == nil then |
|
|
self.mask = inputMask(input:size(1), input:size(2), input:size(3)) |
|
|
end |
|
|
--print('updateGradInput', input:size(), self.mask:size()) |
|
|
local inp = input:clone():cmul(self.mask) |
|
|
if self.mode == 'loss' then |
|
|
local dG = self.crit:backward(self.G, self.target) |
|
|
dG:div(input:nElement()) |
|
|
self.gradInput = self.gram:backward(inp, dG) |
|
|
if self.normalize then |
|
|
self.gradInput:div(torch.norm(self.gradInput, 1) + 1e-8) |
|
|
end |
|
|
self.gradInput:mul(self.strength) |
|
|
self.gradInput:add(gradOutput) |
|
|
else |
|
|
self.gradInput = gradOutput |
|
|
end |
|
|
return self.gradInput |
|
|
end |
|
|
|
|
|
|
|
|
local TVLoss, parent = torch.class('nn.TVLoss', 'nn.Module') |
|
|
|
|
|
function TVLoss:__init(strength) |
|
|
parent.__init(self) |
|
|
self.strength = strength |
|
|
self.x_diff = torch.Tensor() |
|
|
self.y_diff = torch.Tensor() |
|
|
end |
|
|
|
|
|
function TVLoss:updateOutput(input) |
|
|
self.output = input |
|
|
return self.output |
|
|
end |
|
|
|
|
|
-- TV loss backward pass inspired by kaishengtai/neuralart |
|
|
function TVLoss:updateGradInput(input, gradOutput) |
|
|
self.gradInput:resizeAs(input):zero() |
|
|
local C, H, W = input:size(1), input:size(2), input:size(3) |
|
|
self.x_diff:resize(3, H - 1, W - 1) |
|
|
self.y_diff:resize(3, H - 1, W - 1) |
|
|
self.x_diff:copy(input[{{}, {1, -2}, {1, -2}}]) |
|
|
self.x_diff:add(-1, input[{{}, {1, -2}, {2, -1}}]) |
|
|
self.y_diff:copy(input[{{}, {1, -2}, {1, -2}}]) |
|
|
self.y_diff:add(-1, input[{{}, {2, -1}, {1, -2}}]) |
|
|
self.gradInput[{{}, {1, -2}, {1, -2}}]:add(self.x_diff):add(self.y_diff) |
|
|
self.gradInput[{{}, {1, -2}, {2, -1}}]:add(-1, self.x_diff) |
|
|
self.gradInput[{{}, {2, -1}, {1, -2}}]:add(-1, self.y_diff) |
|
|
self.gradInput:mul(self.strength) |
|
|
self.gradInput:add(gradOutput) |
|
|
return self.gradInput |
|
|
end |
|
|
|
|
|
|
|
|
local params = cmd:parse(arg) |
|
|
main(params) |