Last active July 6, 2019 13:34
Revisions
karpathy revised this gist
May 6, 2015 · 1 changed file with 2 additions and 24 deletions.
@@ -1,3 +1,4 @@
--[[
This layer expects an [n x d] Tensor and normalizes each row to have unit L2 norm.

@@ -32,31 +33,8 @@ function L2Normalize:updateGradInput(input, gradOutput)
    local b2 = input:view(n,1,d)
    self.diag:add(-torch.bmm(b1,b2))
    -- compute the local gradient of the L2 transformation
    self.diag:cdiv(torch.pow(self.buffer,3):view(n,1,1):expand(n,d,d))
    -- chain the gradient
    self.gradInput:resize(n,d,1):bmm(self.diag, gradOutput:view(n,d,1)):resize(n,d)
    return self.gradInput
end
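For reference, the per-row Jacobian that this vectorized backward pass implements (standard calculus, stated here for clarity; the notation below is not from the gist): with y = x / ||x|| for a row x in R^d,

\[
\frac{\partial y_i}{\partial x_j}
= \frac{\delta_{ij}}{\lVert x\rVert} - \frac{x_i\,x_j}{\lVert x\rVert^{3}}
= \frac{\lVert x\rVert^{2}\,\delta_{ij} - x_i\,x_j}{\lVert x\rVert^{3}},
\qquad
\frac{\partial L}{\partial x}
= \frac{\lVert x\rVert^{2} I - x\,x^{\top}}{\lVert x\rVert^{3}}\,\frac{\partial L}{\partial y}.
\]

This maps line by line onto the code: the diagonal term is self.eye scaled by self.normSquared, the cross term is the negated bmm of x with its transpose, the division by ||x||^3 is the cdiv by self.buffer cubed (self.buffer holds the row norms), and the final bmm against gradOutput chains the gradient.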
karpathy revised this gist
May 6, 2015 · 1 changed file with 14 additions and 12 deletions.
@@ -10,8 +10,11 @@ function L2Normalize:updateOutput(input)
    assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got ' .. input:dim() .. 'D tensor instead')
    self.output:resizeAs(input)
    self.buffer = self.buffer or input.new()
    self.normSquared = self.normSquared or input.new()
    self.normSquared:sum(self.buffer:cmul(input, input), 2)
    self.buffer:sqrt(self.normSquared)
    self.output:copy(input):cdiv(self.buffer:expandAs(input))
    return self.output
end

@@ -20,20 +23,19 @@ function L2Normalize:updateGradInput(input, gradOutput)
    assert(gradOutput:dim() == 2, 'only mini-batch supported')
    local n = input:size(1) -- batch size
    local d = input:size(2) -- dimensionality of vectors
    -- compute diagonal term
    self.eye = self.eye or torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d)
    self.diag = self.diag or self.eye.new()
    self.diag:cmul(self.eye, self.normSquared:view(n,1,1):expand(n,d,d))
    -- compute cross term
    local b1 = input:view(n,d,1)
    local b2 = input:view(n,1,d)
    self.diag:add(-torch.bmm(b1,b2))
    -- compute the local gradient of the L2 transformation
    self.buffer:pow(3)
    self.diag:cdiv(self.buffer:view(n,1,1):expand(n,d,d))
    -- chain the gradient
    self.gradInput:resize(n,d,1):bmm(self.diag, gradOutput:view(n,d,1)):resize(n,d)
    return self.gradInput
end
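A minimal usage sketch, not part of the gist: it assumes Torch7 with the nn package installed and that the gist file has been saved locally as L2Normalize.lua (hypothetical filename).

    require 'torch'
    require 'nn'
    dofile 'L2Normalize.lua'           -- hypothetical local filename; defines nn.L2Normalize

    local n, d = 5, 10
    local m = nn.L2Normalize()
    local x = torch.randn(n, d)

    local y = m:forward(x)
    print(y:norm(2, 2))                -- every row should come out with L2 norm ~ 1

    local dy = torch.randn(n, d)
    local dx = m:backward(x, dy)
    print(dx:size())                   -- same shape as the input: n x d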
karpathy revised this gist
May 5, 2015 · 1 changed file with 1 addition and 1 deletion.
@@ -33,7 +33,7 @@ function L2Normalize:updateGradInput(input, gradOutput)
    -- compute the local gradient of the L2 transformation
    local dsum = torch.cdiv(diag + cross, divterms:view(n,1,1):expand(n,d,d))
    -- chain the gradient
    self.gradInput = torch.bmm(dsum, gradOutput:view(n,d,1)):squeeze()
    return self.gradInput
end
karpathy revised this gist
May 5, 2015 · 1 changed file with 29 additions and 23 deletions.
@@ -1,37 +1,43 @@
--[[
This layer expects an [n x d] Tensor and normalizes each row to have unit L2 norm.
]]--

local L2Normalize, parent = torch.class('nn.L2Normalize', 'nn.Module')

function L2Normalize:__init()
    parent.__init(self)
end

function L2Normalize:updateOutput(input)
    assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got ' .. input:dim() .. 'D tensor instead')
    self.output:resizeAs(input)
    local norms = torch.cmul(input,input):sum(2):sqrt()
    self.output:copy(input):cdiv(norms:expandAs(input))
    return self.output
end

function L2Normalize:updateGradInput(input, gradOutput)
    assert(input:dim() == 2, 'only mini-batch supported')
    assert(gradOutput:dim() == 2, 'only mini-batch supported')
    local n = input:size(1) -- batch size
    local d = input:size(2) -- dimensionality of vectors
    local sums = torch.sum(torch.cmul(input,input), 2):view(-1)
    local divterms = torch.pow(sums,3/2)
    -- compute diagonal term
    local diag = torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d)
                 :cmul(sums:view(n,1,1):expand(n,d,d))
    -- compute cross term
    local b1 = input:reshape(n,d,1)
    local b2 = input:reshape(n,1,d)
    local cross = - torch.bmm(b1,b2)
    -- compute the local gradient of the L2 transformation
    local dsum = torch.cdiv(diag + cross, divterms:view(n,1,1):expand(n,d,d))
    -- chain the gradient
    self.gradInput = torch.bmm(dsum, gradOutput:view(n,d,1))
    return self.gradInput
end

--[[
-- for reference, a cleaner but slower implementation that loops
-- over all input vectors
karpathy revised this gist
May 5, 2015 · 1 changed file with 22 additions and 1 deletion.
@@ -30,4 +30,25 @@ function L2Normalize:updateGradInput(input, gradOutput)
    -- chain the gradient
    self.gradInput = torch.bmm(dsum, gradOutput:view(n,d,1))
    return self.gradInput
end

--[[
-- for reference, a cleaner but slower implementation that loops
-- over all input vectors
function L2Normalize:updateGradInputSlow(input, gradOutput)
    self.gradInput:resizeAs(gradOutput)
    local n = input:size(1)
    local d = input:size(2)
    for i=1,n do
        local x = input[{{i,i}}]
        local s = torch.sum(torch.cmul(x,x))
        local diag = torch.eye(d) * s
        local cross = - torch.mm(x:t(),x)
        local divterm = torch.pow(s,3/2)
        self.gradInput[{i}] = torch.mm((diag + cross) / divterm, gradOutput[{{i,i}}]:t())
    end
    return self.gradInput
end
]]--
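A quick gradient-check sketch for the analytic backward above, not part of the gist: it assumes Torch7's nn package, whose nn.Jacobian helper compares a module's analytic Jacobian against a finite-difference estimate, and again assumes the file is saved locally as L2Normalize.lua (hypothetical filename).

    require 'nn'
    dofile 'L2Normalize.lua'                         -- hypothetical local filename

    local m = nn.L2Normalize()
    local input = torch.randn(4, 7)                  -- shape template; testJacobian re-randomizes values

    -- max abs difference between the analytic and the finite-difference Jacobian
    local err = nn.Jacobian.testJacobian(m, input)
    print('jacobian error: ' .. err)
    assert(err < 1e-5, 'gradient check failed')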
karpathy renamed this gist
May 5, 2015 · 1 changed file with 0 additions and 0 deletions.
File renamed without changes.
karpathy created this gist
May 5, 2015.
@@ -0,0 +1,33 @@
--[[
This layer expects [n x d] input and normalizes the rows.
That is, each d-dimensional row is treated as a vector that is (L2) normalized
]]--

local L2Normalize, parent = torch.class('nn.L2Normalize', 'nn.Module')

function L2Normalize:__init()
    parent.__init(self)
end

function L2Normalize:updateOutput(input)
    self.output:resizeAs(input)
    local norms = torch.cmul(input,input):sum(2):sqrt()
    self.output:copy(input):cdiv(norms:expandAs(input))
    return self.output
end

function L2Normalize:updateGradInput(input, gradOutput)
    self.gradInput:resizeAs(gradOutput)
    local n = input:size(1) -- batch size
    local d = input:size(2) -- dimensionality of vectors
    local sums = torch.sum(torch.cmul(input,input), 2):view(-1)
    local divterms = torch.pow(sums,3/2)
    -- compute diagonal term
    local diag_mat = torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d):cmul(sums:view(n,1,1):expand(n,d,d))
    -- compute cross term
    local b1 = input:reshape(n,d,1)
    local b2 = input:reshape(n,1,d)
    local cross_mat = - torch.bmm(b1,b2)
    -- compute the local gradient of the L2 transformation
    local dsum = torch.cdiv(diag_mat + cross_mat, divterms:view(n,1,1):expand(n,d,d))
    -- chain the gradient
    self.gradInput = torch.bmm(dsum, gradOutput:view(n,d,1))
    return self.gradInput
end
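The forward pass leans on Torch's expandAs to broadcast the n x 1 column of per-row norms across the n x d input without allocating a full copy. A standalone sketch of that same pattern outside the module (illustration only, not from the gist):

    require 'torch'

    local x = torch.randn(3, 4)
    local norms = torch.cmul(x, x):sum(2):sqrt()     -- 3 x 1 column of per-row L2 norms
    local y = torch.cdiv(x, norms:expandAs(x))       -- expanded view: divide each row by its norm, no copy
    print(torch.cmul(y, y):sum(2))                   -- ~1 for every row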