@karpathy
Last active July 6, 2019 13:34

Revisions

  1. karpathy revised this gist May 6, 2015. 1 changed file with 2 additions and 24 deletions.
    26 changes: 2 additions & 24 deletions gistfile1.lua
    @@ -1,3 +1,4 @@
    +
      --[[
      This layer expects an [n x d] Tensor and normalizes each
      row to have unit L2 norm.
    @@ -32,31 +33,8 @@ function L2Normalize:updateGradInput(input, gradOutput)
      local b2 = input:view(n,1,d)
      self.diag:add(-torch.bmm(b1,b2))
      -- compute the local gradient of the L2 transformation
    - self.buffer:pow(3)
    - self.diag:cdiv(self.buffer:view(n,1,1):expand(n,d,d))
    + self.diag:cdiv(torch.pow(self.buffer,3):view(n,1,1):expand(n,d,d))
      -- chain the gradient
      self.gradInput:resize(n,d,1):bmm(self.diag, gradOutput:view(n,d,1)):resize(n,d)
      return self.gradInput
      end
    -
    -
    - --[[
    - -- for reference, a cleaner but slower implementation that loops
    - -- over all input vectors
    - function L2Normalize:updateGradInputSlow(input, gradOutput)
    - self.gradInput:resizeAs(gradOutput)
    - local n = input:size(1)
    - local d = input:size(2)
    - for i=1,n do
    - local x = input[{{i,i}}]
    - local s = torch.sum(torch.cmul(x,x))
    - local diag = torch.eye(d) * s
    - local cross = - torch.mm(x:t(),x)
    - local divterm = torch.pow(s,3/2)
    - self.gradInput[{i}] = torch.mm((diag + cross) / divterm, gradOutput[{{i,i}}]:t())
    - end
    - return self.gradInput
    - end
    - ]]--
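
    For reference, the local gradient assembled above (a diagonal term minus the outer-product cross term, divided by the cubed norm) is the Jacobian of row-wise L2 normalization. For a single row x with squared norm s, the layer computes y = x / sqrt(s), and

        \frac{\partial y_i}{\partial x_j} = \frac{s\,\delta_{ij} - x_i x_j}{s^{3/2}},
        \qquad s = \lVert x \rVert_2^2

    which is exactly what the batched bmm code evaluates, with torch.pow(self.buffer,3) supplying the per-row s^{3/2} denominator.
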
  2. karpathy revised this gist May 6, 2015. 1 changed file with 14 additions and 12 deletions.
    26 changes: 14 additions & 12 deletions gistfile1.lua
    @@ -10,8 +10,11 @@ function L2Normalize:updateOutput(input)
      assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got '
        .. input:dim() .. 'D tensor instead')
      self.output:resizeAs(input)
    - local norms = torch.cmul(input,input):sum(2):sqrt()
    - self.output:copy(input):cdiv(norms:expandAs(input))
    + self.buffer = self.buffer or input.new()
    + self.normSquared = self.normSquared or input.new()
    + self.normSquared:sum(self.buffer:cmul(input, input), 2)
    + self.buffer:sqrt(self.normSquared)
    + self.output:copy(input):cdiv(self.buffer:expandAs(input))
      return self.output
      end

    @@ -20,20 +23,19 @@ function L2Normalize:updateGradInput(input, gradOutput)
      assert(gradOutput:dim() == 2, 'only mini-batch supported')
      local n = input:size(1) -- batch size
      local d = input:size(2) -- dimensionality of vectors
    -
    - local sums = torch.sum(torch.cmul(input,input), 2):view(-1)
    - local divterms = torch.pow(sums,3/2)
      -- compute diagonal term
    - local diag = torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d)
    -   :cmul(sums:view(n,1,1):expand(n,d,d))
    + self.eye = self.eye or torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d)
    + self.diag = self.diag or self.eye.new()
    + self.diag:cmul(self.eye, self.normSquared:view(n,1,1):expand(n,d,d))
      -- compute cross term
    - local b1 = input:reshape(n,d,1)
    - local b2 = input:reshape(n,1,d)
    - local cross = - torch.bmm(b1,b2)
    + local b1 = input:view(n,d,1)
    + local b2 = input:view(n,1,d)
    + self.diag:add(-torch.bmm(b1,b2))
      -- compute the local gradient of the L2 transformation
    - local dsum = torch.cdiv(diag + cross, divterms:view(n,1,1):expand(n,d,d))
    + self.buffer:pow(3)
    + self.diag:cdiv(self.buffer:view(n,1,1):expand(n,d,d))
      -- chain the gradient
    - self.gradInput = torch.bmm(dsum, gradOutput:view(n,d,1)):squeeze()
    + self.gradInput:resize(n,d,1):bmm(self.diag, gradOutput:view(n,d,1)):resize(n,d)
      return self.gradInput
      end
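
    A buffer-reusing rewrite like this is easy to get subtly wrong, so a finite-difference comparison against updateGradInput is a useful sanity check. A minimal sketch, assuming the layer above has been saved to a file and loaded (the filename below is hypothetical):

        -- finite-difference gradient check (sketch; 'L2Normalize.lua' is a hypothetical filename)
        require 'nn'
        dofile('L2Normalize.lua')

        local m = nn.L2Normalize()
        local x = torch.randn(4, 3)
        local dout = torch.randn(4, 3)

        m:forward(x)
        local analytic = m:updateGradInput(x, dout):clone()

        -- numerically estimate d<output, dout>/dx entry by entry with central differences
        local eps = 1e-5
        local numeric = torch.zeros(4, 3)
        for i = 1, 4 do
          for j = 1, 3 do
            local orig = x[i][j]
            x[i][j] = orig + eps
            local fplus = torch.cmul(m:forward(x), dout):sum()
            x[i][j] = orig - eps
            local fminus = torch.cmul(m:forward(x), dout):sum()
            x[i][j] = orig
            numeric[i][j] = (fplus - fminus) / (2 * eps)
          end
        end

        print('max abs difference:', (analytic - numeric):abs():max())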

  3. karpathy revised this gist May 5, 2015. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion gistfile1.lua
    @@ -33,7 +33,7 @@ function L2Normalize:updateGradInput(input, gradOutput)
      -- compute the local gradient of the L2 transformation
      local dsum = torch.cdiv(diag + cross, divterms:view(n,1,1):expand(n,d,d))
      -- chain the gradient
    - self.gradInput = torch.bmm(dsum, gradOutput:view(n,d,1))
    + self.gradInput = torch.bmm(dsum, gradOutput:view(n,d,1)):squeeze()
      return self.gradInput
      end
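
    The effect of the :squeeze() added above: torch.bmm of an (n x d x d) batch of Jacobians with an (n x d x 1) batch of gradOutput columns produces an (n x d x 1) tensor, while gradInput should match the (n x d) input. A quick shape check:

        local n, d = 4, 3
        local J = torch.randn(n, d, d)
        local g = torch.randn(n, d, 1)
        print(torch.bmm(J, g):size())            -- n x d x 1
        print(torch.bmm(J, g):squeeze():size())  -- n x d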

  4. karpathy revised this gist May 5, 2015. 1 changed file with 29 additions and 23 deletions.
    52 changes: 29 additions & 23 deletions gistfile1.lua
    @@ -1,37 +1,43 @@
      --[[
    - This layer expects [n x d] input and normalizes the rows. That is, each
    - d-dimensional row is treated as a vector that is (L2) normalized
    + This layer expects an [n x d] Tensor and normalizes each
    + row to have unit L2 norm.
      ]]--
      local L2Normalize, parent = torch.class('nn.L2Normalize', 'nn.Module')
      function L2Normalize:__init()
    - parent.__init(self)
    + parent.__init(self)
      end
      function L2Normalize:updateOutput(input)
    - self.output:resizeAs(input)
    - local norms = torch.cmul(input,input):sum(2):sqrt()
    - self.output:copy(input):cdiv(norms:expandAs(input))
    - return self.output
    + assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got '
    +   .. input:dim() .. 'D tensor instead')
    + self.output:resizeAs(input)
    + local norms = torch.cmul(input,input):sum(2):sqrt()
    + self.output:copy(input):cdiv(norms:expandAs(input))
    + return self.output
      end

      function L2Normalize:updateGradInput(input, gradOutput)
    - self.gradInput:resizeAs(gradOutput)
    - local n = input:size(1) -- batch size
    - local d = input:size(2) -- dimensionality of vectors
    - local sums = torch.sum(torch.cmul(input,input), 2):view(-1)
    - local divterms = torch.pow(sums,3/2)
    - -- compute diagonal term
    - local diag_mat = torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d):cmul(sums:view(n,1,1):expand(n,d,d))
    - -- compute cross term
    - local b1 = input:reshape(n,d,1)
    - local b2 = input:reshape(n,1,d)
    - local cross_mat = - torch.bmm(b1,b2)
    - -- compute the local gradient of the L2 transformation
    - local dsum = torch.cdiv(diag_mat + cross_mat, divterms:view(n,1,1):expand(n,d,d))
    - -- chain the gradient
    - self.gradInput = torch.bmm(dsum, gradOutput:view(n,d,1))
    - return self.gradInput
    + assert(input:dim() == 2, 'only mini-batch supported')
    + assert(gradOutput:dim() == 2, 'only mini-batch supported')
    + local n = input:size(1) -- batch size
    + local d = input:size(2) -- dimensionality of vectors
    +
    + local sums = torch.sum(torch.cmul(input,input), 2):view(-1)
    + local divterms = torch.pow(sums,3/2)
    + -- compute diagonal term
    + local diag = torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d)
    +   :cmul(sums:view(n,1,1):expand(n,d,d))
    + -- compute cross term
    + local b1 = input:reshape(n,d,1)
    + local b2 = input:reshape(n,1,d)
    + local cross = - torch.bmm(b1,b2)
    + -- compute the local gradient of the L2 transformation
    + local dsum = torch.cdiv(diag + cross, divterms:view(n,1,1):expand(n,d,d))
    + -- chain the gradient
    + self.gradInput = torch.bmm(dsum, gradOutput:view(n,d,1))
    + return self.gradInput
      end


      --[[
      -- for reference, a cleaner but slower implementation that loops
      -- over all input vectors
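
    With the asserts added in this revision, the layer rejects anything but a 2D mini-batch; a single vector has to be wrapped as a 1 x d batch first. For example:

        local m = nn.L2Normalize()
        local v = torch.randn(5)
        -- m:forward(v) would now error: "only mini-batch supported (2D tensor), got 1D tensor instead"
        local y = m:forward(v:view(1, -1))  -- works: a 1 x 5 batch
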
  5. karpathy revised this gist May 5, 2015. 1 changed file with 22 additions and 1 deletion.
    23 changes: 22 additions & 1 deletion gistfile1.lua
    @@ -30,4 +30,25 @@ function L2Normalize:updateGradInput(input, gradOutput)
      -- chain the gradient
      self.gradInput = torch.bmm(dsum, gradOutput:view(n,d,1))
      return self.gradInput
    - end
    + end
    +
    + --[[
    + -- for reference, a cleaner but slower implementation that loops
    + -- over all input vectors
    + function L2Normalize:updateGradInputSlow(input, gradOutput)
    + self.gradInput:resizeAs(gradOutput)
    + local n = input:size(1)
    + local d = input:size(2)
    + for i=1,n do
    + local x = input[{{i,i}}]
    + local s = torch.sum(torch.cmul(x,x))
    + local diag = torch.eye(d) * s
    + local cross = - torch.mm(x:t(),x)
    + local divterm = torch.pow(s,3/2)
    + self.gradInput[{i}] = torch.mm((diag + cross) / divterm, gradOutput[{{i,i}}]:t())
    + end
    + return self.gradInput
    + end
    + ]]--
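
    As a small worked instance of the per-row formula in this reference version: for a single row x = (3, 4), s = 25, the output is x / sqrt(s) = (0.6, 0.8), and the local gradient matrix is

        \frac{s I - x x^\top}{s^{3/2}}
          = \frac{1}{125} \begin{pmatrix} 25 - 9 & -12 \\ -12 & 25 - 16 \end{pmatrix}
          = \begin{pmatrix} 0.128 & -0.096 \\ -0.096 & 0.072 \end{pmatrix}

    which is then multiplied into that row's gradOutput to chain the gradient.
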
  6. karpathy renamed this gist May 5, 2015. 1 changed file with 0 additions and 0 deletions.
    File renamed without changes.
  7. karpathy created this gist May 5, 2015.
    33 changes: 33 additions & 0 deletions gistfile1.txt
    @@ -0,0 +1,33 @@
    --[[
    This layer expects [n x d] input and normalizes the rows. That is, each
    d-dimensional row is treated as a vector that is (L2) normalized
    ]]--
    local L2Normalize, parent = torch.class('nn.L2Normalize', 'nn.Module')
    function L2Normalize:__init()
    parent.__init(self)
    end
    function L2Normalize:updateOutput(input)
    self.output:resizeAs(input)
    local norms = torch.cmul(input,input):sum(2):sqrt()
    self.output:copy(input):cdiv(norms:expandAs(input))
    return self.output
    end

    function L2Normalize:updateGradInput(input, gradOutput)
    self.gradInput:resizeAs(gradOutput)
    local n = input:size(1) -- batch size
    local d = input:size(2) -- dimensionality of vectors
    local sums = torch.sum(torch.cmul(input,input), 2):view(-1)
    local divterms = torch.pow(sums,3/2)
    -- compute diagonal term
    local diag_mat = torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d):cmul(sums:view(n,1,1):expand(n,d,d))
    -- compute cross term
    local b1 = input:reshape(n,d,1)
    local b2 = input:reshape(n,1,d)
    local cross_mat = - torch.bmm(b1,b2)
    -- compute the local gradient of the L2 transformation
    local dsum = torch.cdiv(diag_mat + cross_mat, divterms:view(n,1,1):expand(n,d,d))
    -- chain the gradient
    self.gradInput = torch.bmm(dsum, gradOutput:view(n,d,1))
    return self.gradInput
    end
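
    A minimal usage sketch for the layer above (the filename, layer sizes, and batch size are illustrative only):

        require 'nn'
        dofile('L2Normalize.lua')        -- hypothetical filename for the code above

        local net = nn.Sequential()
        net:add(nn.Linear(10, 5))        -- map each input to a 5-d vector
        net:add(nn.L2Normalize())        -- normalize each row to unit L2 norm

        local x = torch.randn(8, 10)     -- mini-batch of 8 inputs
        local y = net:forward(x)
        print(torch.cmul(y, y):sum(2))   -- each entry should be (numerically) 1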