import torch
import torch.fft  # needed on PyTorch 1.7 for the torch.fft module; harmless on newer versions
import torch.nn as nn


class CompactBilinearPooling(nn.Module):
    """Compact bilinear pooling of two feature maps via sketch projections and FFT.

    At every spatial position, the two input feature vectors are each projected with
    their own sketch matrix. The two projections are then combined by a circular
    convolution computed in the frequency domain: real FFT, elementwise complex
    product, inverse real FFT.

    No ``__init__`` is defined in this class body, so the following attributes must be
    assigned before ``forward`` is called (NOTE(review): confirm where they are set):

    * ``input_dim1`` / ``input_dim2`` -- channel counts of ``bottom1`` / ``bottom2``.
    * ``output_dim`` -- length of the pooled feature vector.
    * ``sum_pool`` -- if truthy, the result is summed over all spatial positions.
    * ``sparse_sketch_matrix1`` / ``sparse_sketch_matrix2`` -- matrices right-multiplied
      onto the flattened inputs; expected shape ``(input_dimK, output_dim)``
      (presumably Count Sketch projections -- TODO confirm).
    """

    def forward(self, bottom1, bottom2):
        """Pool two channels-first feature maps into compact bilinear features.

        Args:
            bottom1: tensor of shape ``(batch, input_dim1, height, width)``.
            bottom2: tensor of shape ``(batch, input_dim2, height, width)``; batch and
                spatial sizes must match ``bottom1``.

        Returns:
            ``(batch, output_dim)`` when ``self.sum_pool`` is truthy, otherwise
            ``(batch, output_dim, height, width)``.

        Raises:
            AssertionError: if channel counts differ from ``input_dim1`` /
                ``input_dim2``, or if the inputs' batch/spatial sizes differ.
        """
        assert bottom1.size(1) == self.input_dim1 and bottom2.size(1) == self.input_dim2, (
            "channel dimension does not match input_dim1 / input_dim2"
        )
        batch_size, _, height, width = bottom1.size()
        # Without this, a size-1 batch/spatial extent in one input would silently
        # broadcast in the frequency-domain product below.
        assert bottom2.size(0) == batch_size and bottom2.shape[2:] == bottom1.shape[2:], (
            "bottom1 and bottom2 must share batch and spatial dimensions"
        )

        # Channels last, then flatten so each row is one (batch, y, x) position.
        bottom1_flat = bottom1.permute(0, 2, 3, 1).reshape(-1, self.input_dim1)
        bottom2_flat = bottom2.permute(0, 2, 3, 1).reshape(-1, self.input_dim2)

        sketch_1 = bottom1_flat.mm(self.sparse_sketch_matrix1)
        sketch_2 = bottom2_flat.mm(self.sparse_sketch_matrix2)

        # Circular convolution of the two sketches = inverse FFT of the product of
        # their spectra. Complex tensors give the exact complex product, avoiding a
        # hand-expanded (a+bi)(c+di). Passing n=output_dim keeps the output length
        # correct for odd sizes.
        fft_product = torch.fft.rfft(sketch_1, dim=-1) * torch.fft.rfft(sketch_2, dim=-1)
        cbp_flat = torch.fft.irfft(fft_product, n=self.output_dim, dim=-1)

        # irfft's default ("backward") normalization divides by n; the result is
        # scaled by output_dim, as in the original implementation.
        cbp = cbp_flat.view(batch_size, height, width, self.output_dim) * self.output_dim

        if self.sum_pool:
            return cbp.sum(dim=(1, 2))
        return cbp.permute(0, 3, 1, 2)