Skip to content

Instantly share code, notes, and snippets.

@vadimkantorov
Last active September 22, 2021 07:51
Show Gist options
  • Select an option

  • Save vadimkantorov/d9b56f9b85f1f4ce59ffecf893a1581a to your computer and use it in GitHub Desktop.

Select an option

Save vadimkantorov/d9b56f9b85f1f4ce59ffecf893a1581a to your computer and use it in GitHub Desktop.
Compact Bilinear Pooling in PyTorch using the new FFT support
import torch
import torch.nn as nn
class CompactBilinearPooling(nn.Module):
    """Compact bilinear pooling of two NCHW feature maps via count sketch + FFT.

    Each input's channel vector (at every spatial position) is projected with a
    fixed random count-sketch matrix; the two sketches are then combined by a
    circular convolution, computed as an element-wise complex product in the
    frequency domain.

    Args:
        input_dim1: number of channels of ``bottom1``.
        input_dim2: number of channels of ``bottom2``.
        output_dim: dimension of the sketched bilinear feature.
        sum_pool: if True, sum over spatial positions and return
            ``(batch, output_dim)``; otherwise return
            ``(batch, output_dim, H, W)``.
        rand_h_1, rand_h_2: optional hash indices in ``[0, output_dim)``, one
            per input channel; sampled uniformly when None.
        rand_s_1, rand_s_2: optional signs (+1 / -1), one per input channel;
            sampled uniformly when None.
    """

    def __init__(self, input_dim1, input_dim2, output_dim, sum_pool=True,
                 rand_h_1=None, rand_s_1=None, rand_h_2=None, rand_s_2=None):
        super().__init__()
        self.input_dim1 = input_dim1
        self.input_dim2 = input_dim2
        self.output_dim = output_dim
        self.sum_pool = sum_pool

        if rand_h_1 is None:
            rand_h_1 = torch.randint(output_dim, size=(input_dim1,))
        if rand_s_1 is None:
            rand_s_1 = 2 * torch.randint(2, size=(input_dim1,)) - 1
        if rand_h_2 is None:
            rand_h_2 = torch.randint(output_dim, size=(input_dim2,))
        if rand_s_2 is None:
            rand_s_2 = 2 * torch.randint(2, size=(input_dim2,)) - 1

        # Registered as buffers (not parameters): they are fixed, but they must
        # follow the module across .to(device) and be stored in state_dict so a
        # reloaded model keeps the same hash functions.
        self.register_buffer(
            'sparse_sketch_matrix1',
            self.generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, output_dim))
        self.register_buffer(
            'sparse_sketch_matrix2',
            self.generate_sketch_matrix(rand_h_2, rand_s_2, input_dim2, output_dim))

    @staticmethod
    def _sketch(bottom, sketch_matrix):
        """Project NCHW ``bottom`` to count sketches of shape (N*H*W, output_dim)."""
        channels = bottom.size(1)
        # Channels last, then flatten all positions so .mm() sees a 2-D matrix.
        flat = bottom.permute(0, 2, 3, 1).reshape(-1, channels)
        return flat.mm(sketch_matrix.to(dtype=flat.dtype))

    def forward(self, bottom1, bottom2):
        """Pool ``bottom1`` (N, input_dim1, H, W) with ``bottom2`` (N, input_dim2, H, W).

        Returns:
            ``(N, output_dim)`` if ``sum_pool`` else ``(N, output_dim, H, W)``.

        Raises:
            ValueError: if the inputs are not 4-D or their batch/spatial sizes differ.
        """
        if bottom1.dim() != 4 or bottom2.dim() != 4:
            raise ValueError('CompactBilinearPooling expects 4-D (N, C, H, W) inputs')
        if bottom1.shape[0] != bottom2.shape[0] or bottom1.shape[2:] != bottom2.shape[2:]:
            raise ValueError('bottom1 and bottom2 must share batch and spatial sizes')

        batch_size, _, height, width = bottom1.shape
        sketch_1 = self._sketch(bottom1, self.sparse_sketch_matrix1)
        sketch_2 = self._sketch(bottom2, self.sparse_sketch_matrix2)

        # Circular convolution of the sketches == inverse FFT of the product of
        # their FFTs; complex tensors make the multiplication exact.
        fft_product = torch.fft.rfft(sketch_1, dim=-1) * torch.fft.rfft(sketch_2, dim=-1)
        # n= restores the exact length (needed when output_dim is odd). The
        # extra factor of output_dim keeps the original implementation's scaling.
        cbp = torch.fft.irfft(fft_product, n=self.output_dim, dim=-1) * self.output_dim
        cbp = cbp.view(batch_size, height, width, self.output_dim)
        return cbp.sum(dim=(1, 2)) if self.sum_pool else cbp.permute(0, 3, 1, 2)

    @staticmethod
    def generate_sketch_matrix(rand_h, rand_s, input_dim, output_dim):
        """Build the dense (input_dim, output_dim) float32 count-sketch matrix.

        Row ``i`` has a single non-zero entry ``rand_s[i]`` at column ``rand_h[i]``.
        """
        rand_h = torch.as_tensor(rand_h, dtype=torch.long)
        rand_s = torch.as_tensor(rand_s, dtype=torch.float32, device=rand_h.device)
        sketch = torch.zeros(input_dim, output_dim, dtype=torch.float32, device=rand_h.device)
        sketch[torch.arange(input_dim, device=rand_h.device), rand_h] = rand_s
        return sketch
@pangjh3
Copy link

pangjh3 commented Apr 20, 2018

Thanks for your code. How do I install the new FFT support?

@vadimkantorov
Copy link
Author

Just install PyTorch from the master branch; version 0.4 probably has FFT support as well.

@ayumiymk
Copy link

First, thanks for your code. I have a question: other implementations, such as the Torch and TensorFlow versions, zero-pad the tensor before feeding it into the FFT, but I don't see that zero-padding in this code.

Thanks very much!

@hj0921
Copy link

hj0921 commented Mar 19, 2021

hello,

In `torch.stack([torch.arange(in_features), rand_h])`, `in_features` is not defined. How can I fix it?

thanks!

@vadimkantorov
Copy link
Author

Thanks for noting this. Fixed! It should have been in_channels

@vadimkantorov
Copy link
Author

Some ways to improve the code: use the new PyTorch `torch.fft` module and complex-tensor support, and figure out a dense × sparse matmul (currently I'm materializing the sparse sketch matrix as a dense one).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment