Created
February 6, 2018 08:25
-
-
Save lantiga/a78581e6c6c0ad1534065950e204ce9d to your computer and use it in GitHub Desktop.
Revisions
-
lantiga created this gist
Feb 6, 2018 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,27 @@ # Indexed convolutions A convolution operator over a 1D tensor (BxCxL), where a list of neighbors for each element is provided through a indices tensor (LxK), where K is the size of the convolution kernel. Each row of indices specifies the indices of the K neighbors of the corresponding element in the input. A -1 is handled like for zero padding. Note that the neighbors specified in indices are not relative, but rather absolute. They have to be specified for each of the elements of the output. A use case is for convolutions over non-square lattices, such as images on hexagonal lattices coming from Cherenkov telescopes (http://www.isdc.unige.ch/%7Elyard/FirstLight/FirstLight_slowHD.mov). Example: ``` import torch # a 1D input of 5 elems input = torch.randn(1,1,5) # this specifies the indices of neighbors for # each elem of the input (a 3 elem kernel here) # A -1 corresponds to zero-padding indices = torch.ones(5,3).type(torch.LongTensor) weight = torch.randn(1,1,3) bias = torch.randn(1) output = torch.nn.functional.indexed_conv(input, indices, weight, bias) ``` This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,67 @@ import torch from torch.autograd import Variable def prepare_mask(indices): padded = indices == -1 indices[padded] = 0 mask = torch.FloatTensor([1,0]) mask = mask[..., padded.t().long()] return indices, mask def indexed_conv(input, weight, bias, indices, mask): nbatch = input.shape[0] output_width = indices.shape[0] out_chans, in_chans, ksize = weight.shape if isinstance(input, Variable): mask = Variable(mask) col = input[..., indices.t()] * mask col = col.view(nbatch, -1, output_width) weight_col = weight.view(out_chans, -1) out = torch.matmul(weight_col, col) + bias #print(col) #print(weight_col) return out if __name__ == '__main__': # input = torch.randn(1,2,5) # weight = torch.randn(1,2,3) # bias = torch.randn(1) # indices = (5 * torch.rand(4,3)).long() input = torch.ones(1,2,5) weight = torch.ones(1,2,3) bias = torch.zeros(1) indices = (5 * torch.rand(4,3)).long() indices[0,0] = -1 indices, mask = prepare_mask(indices) print(input) print(indices) out = indexed_conv(input, weight, bias, indices, mask) input = Variable(input, requires_grad=True) weight = Variable(weight) bias = Variable(bias) out = indexed_conv(input, weight, bias, indices, mask) print(out) out.sum().backward() print(input.grad)