import torch from torch.autograd import Variable def prepare_mask(indices): padded = indices == -1 indices[padded] = 0 mask = torch.FloatTensor([1,0]) mask = mask[..., padded.t().long()] return indices, mask def indexed_conv(input, weight, bias, indices, mask): nbatch = input.shape[0] output_width = indices.shape[0] out_chans, in_chans, ksize = weight.shape if isinstance(input, Variable): mask = Variable(mask) col = input[..., indices.t()] * mask col = col.view(nbatch, -1, output_width) weight_col = weight.view(out_chans, -1) out = torch.matmul(weight_col, col) + bias #print(col) #print(weight_col) return out if __name__ == '__main__': # input = torch.randn(1,2,5) # weight = torch.randn(1,2,3) # bias = torch.randn(1) # indices = (5 * torch.rand(4,3)).long() input = torch.ones(1,2,5) weight = torch.ones(1,2,3) bias = torch.zeros(1) indices = (5 * torch.rand(4,3)).long() indices[0,0] = -1 indices, mask = prepare_mask(indices) print(input) print(indices) out = indexed_conv(input, weight, bias, indices, mask) input = Variable(input, requires_grad=True) weight = Variable(weight) bias = Variable(bias) out = indexed_conv(input, weight, bias, indices, mask) print(out) out.sum().backward() print(input.grad)