@lantiga · Created February 6, 2018 08:25

# Indexed convolutions

A convolution operator over a 1D tensor (BxCxL), where the neighbors of each element are provided through an indices tensor (LxK), with K the size of the convolution kernel. Each row of indices specifies the indices of the K neighbors of the corresponding element in the input. An index of -1 is treated as zero padding.

Note that the neighbors specified in indices are absolute, not relative: they must be listed explicitly for each element of the output.
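
A quick way to see what "absolute" means: the indices tensor that reproduces an ordinary 1D convolution (kernel size 3, zero padding) simply lists [i-1, i, i+1] for each output element i. A minimal sketch (the `conv1d_indices` helper is illustrative, not part of the gist):

```
import torch

def conv1d_indices(length, ksize=3):
    # row i holds the absolute indices of the ksize neighbors of output
    # element i; positions outside the input become -1 (zero padding)
    rows = []
    for i in range(length):
        row = [i + k - ksize // 2 for k in range(ksize)]
        rows.append([j if 0 <= j < length else -1 for j in row])
    return torch.LongTensor(rows)

# conv1d_indices(5):
# [[-1,  0,  1],
#  [ 0,  1,  2],
#  [ 1,  2,  3],
#  [ 2,  3,  4],
#  [ 3,  4, -1]]
```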

One use case is convolutions over non-square lattices, such as images on hexagonal lattices coming from Cherenkov telescopes (http://www.isdc.unige.ch/%7Elyard/FirstLight/FirstLight_slowHD.mov).

Example:

```
import torch
# a 1D input with 5 elements
input = torch.randn(1, 1, 5)
# the indices of the neighbors of each element of the
# input (a 3-element kernel here); in this toy example
# every neighbor is element 1.
# A -1 corresponds to zero padding.
indices = torch.ones(5, 3).type(torch.LongTensor)
weight = torch.randn(1, 1, 3)
bias = torch.randn(1)
output = torch.nn.functional.indexed_conv(input, indices, weight, bias)
```

(`torch.nn.functional.indexed_conv` is the proposed interface; a reference implementation is sketched in `indexed_conv.py` below.)
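
The same example run against the reference implementation (a sketch assuming `indexed_conv.py` is importable as `indexed_conv`; note that its argument order differs from the proposed functional and that it takes the mask produced by `prepare_mask`):

```
import torch
from indexed_conv import prepare_mask, indexed_conv

input = torch.randn(1, 1, 5)
weight = torch.randn(1, 1, 3)
bias = torch.randn(1)
indices = torch.ones(5, 3).long()

indices, mask = prepare_mask(indices)
output = indexed_conv(input, weight, bias, indices, mask)  # shape (1, 1, 5)
```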
## indexed_conv.py
import torch
from torch.autograd import Variable


def prepare_mask(indices):
    # Mark the -1 (zero-padding) entries, then redirect them to element 0
    # so they are valid for indexing; the mask will zero them out later
    padded = indices == -1
    indices[padded] = 0

    # mask[k, i] is 0.0 where the k-th neighbor of output element i is
    # padding, 1.0 otherwise
    mask = torch.FloatTensor([1, 0])
    mask = mask[..., padded.t().long()]

    return indices, mask
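
# Worked example (sketch, not in the original gist): for
#   indices = [[-1, 0],
#              [ 1, 2]]
# prepare_mask returns indices = [[0, 0], [1, 2]] and
#   mask = [[0., 1.],
#           [1., 1.]]
# (one row per kernel tap), so the contribution of the padded -1
# neighbor is zeroed out inside indexed_conv.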


def indexed_conv(input, weight, bias, indices, mask):
    nbatch = input.shape[0]
    output_width = indices.shape[0]
    out_chans, in_chans, ksize = weight.shape

    if isinstance(input, Variable):
        mask = Variable(mask)

    # Gather the neighbors of every output element (im2col-style) and zero
    # out the padded ones, then fold channels and kernel taps together
    col = input[..., indices.t()] * mask
    col = col.view(nbatch, -1, output_width)

    weight_col = weight.view(out_chans, -1)

    # The convolution becomes a matrix multiplication; reshape bias to
    # (1, out_chans, 1) so it broadcasts over batch and output width
    out = torch.matmul(weight_col, col) + bias.view(1, -1, 1)

    return out
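

# Sanity-check sketch (not in the original gist): with indices listing each
# element's standard [i-1, i, i+1] neighborhood, indexed_conv should match
# torch.nn.functional.conv1d with the same zero padding.
def check_against_conv1d(length=5, ksize=3):
    import torch.nn.functional as F

    # absolute neighbor indices; out-of-range positions become -1
    rows = []
    for i in range(length):
        row = [i + k - ksize // 2 for k in range(ksize)]
        rows.append([j if 0 <= j < length else -1 for j in row])
    indices, mask = prepare_mask(torch.LongTensor(rows))

    input = torch.randn(1, 2, length)
    weight = torch.randn(1, 2, ksize)
    bias = torch.randn(1)

    out = indexed_conv(input, weight, bias, indices, mask)
    ref = F.conv1d(Variable(input), Variable(weight), Variable(bias),
                   padding=ksize // 2)
    print((out - ref.data).abs().max())  # should be ~0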


if __name__ == '__main__':
    # Alternative random initialization:
    # input = torch.randn(1, 2, 5)
    # weight = torch.randn(1, 2, 3)
    # bias = torch.randn(1)
    # indices = (5 * torch.rand(4, 3)).long()

    input = torch.ones(1, 2, 5)
    weight = torch.ones(1, 2, 3)
    bias = torch.zeros(1)
    indices = (5 * torch.rand(4, 3)).long()
    indices[0, 0] = -1

    indices, mask = prepare_mask(indices)

    print(input)
    print(indices)

    # works on plain tensors...
    out = indexed_conv(input, weight, bias, indices, mask)

    # ...and on Variables, with gradients flowing back to the input
    input = Variable(input, requires_grad=True)
    weight = Variable(weight)
    bias = Variable(bias)

    out = indexed_conv(input, weight, bias, indices, mask)

    print(out)

    out.sum().backward()
    print(input.grad)