@vadimkantorov
Last active September 22, 2021 07:51

Revisions

  1. vadimkantorov revised this gist Sep 22, 2021. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion compact_bilinear_pooling.py
    @@ -3,7 +3,7 @@
    # [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062
    # [3] Fast and Scalable Polynomial Kernels via Explicit Feature Maps, Pham and Pagh, https://chbrown.github.io/kdd-2013-usb/kdd/p239.pdf
    # [4] Fastfood — Approximating Kernel Expansions in Loglinear Time, Le et al., https://arxiv.org/abs/1408.3060
    - # Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling
    + # [5] Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling

    # TODO: migrate to use of new native complex64 types
    # TODO: change strided x coo matmul to torch.matmul(): M[sparse_coo] @ M[strided] -> M[strided]
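
For reference, the first TODO above points at the torch.fft module (PyTorch >= 1.8), whose rfft/irfft return and accept native complex64 tensors, so the manual real/imag product in forward() collapses to a plain elementwise multiply. A minimal sketch of that migration, not part of the gist (the function name is illustrative):

    import torch

    def cbp_frequency_product(sketch1, sketch2, out_channels):
        # torch.fft.rfft returns complex64, so the frequency-domain product
        # is an ordinary elementwise complex multiplication
        fft1 = torch.fft.rfft(sketch1, dim = -1)
        fft2 = torch.fft.rfft(sketch2, dim = -1)
        return torch.fft.irfft(fft1 * fft2, n = out_channels, dim = -1) * out_channels
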
  2. vadimkantorov revised this gist Sep 22, 2021. 1 changed file with 3 additions and 0 deletions.
    3 changes: 3 additions & 0 deletions compact_bilinear_pooling.py
    @@ -5,6 +5,9 @@
    # [4] Fastfood — Approximating Kernel Expansions in Loglinear Time, Le et al., https://arxiv.org/abs/1408.3060
    # Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling

    + # TODO: migrate to use of new native complex64 types
    + # TODO: change strided x coo matmul to torch.matmul(): M[sparse_coo] @ M[strided] -> M[strided]
    +
    import torch

    class CompactBilinearPooling(torch.nn.Module):
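
The second TODO asks for the sketch matrix to stay in sparse COO format instead of being densified with .to_dense() at construction, with the sparse operand on the left of the matmul. A hedged sketch of that change, assuming the sketch is stored transposed with shape (out_channels, in_channels) so the sparse operand sits on the left:

    import torch

    # w_sparse: sparse COO tensor of shape (out_channels, in_channels)
    # x: strided tensor of shape (N, in_channels)
    def apply_sketch(w_sparse, x):
        # torch.sparse.mm(sparse, dense) -> dense, i.e. M[sparse_coo] @ M[strided] -> M[strided]
        return torch.sparse.mm(w_sparse, x.t()).t()  # (N, out_channels)
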
  3. vadimkantorov revised this gist Mar 19, 2021. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion compact_bilinear_pooling.py
    @@ -12,7 +12,7 @@ def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True):
    super().__init__()
    self.out_channels = out_channels
    self.sum_pool = sum_pool
    - generate_tensor_sketch = lambda rand_h, rand_s, in_channels, out_channels: torch.sparse.FloatTensor(torch.stack([torch.arange(in_features), rand_h]), rand_s, [in_channels, out_channels]).to_dense()
    + generate_tensor_sketch = lambda rand_h, rand_s, in_channels, out_channels: torch.sparse.FloatTensor(torch.stack([torch.arange(in_channels), rand_h]), rand_s, [in_channels, out_channels]).to_dense()
    self.tenosr_sketch1 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels1,)), 2 * torch.randint(2, size = (in_channels1,), dtype = torch.float32) - 1, in_channels1, out_channels), requires_grad = False)
    self.tensor_sketch2 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels2,)), 2 * torch.randint(2, size = (in_channels2,), dtype = torch.float32) - 1, in_channels2, out_channels), requires_grad = False)
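
The fix above replaces the stray in_features name, a NameError since the lambda only binds in_channels. For reference, the lambda builds the count-sketch projection of [3]: one nonzero per input channel, at column rand_h[i] with sign rand_s[i]. A sketch of the same construction with the non-deprecated torch.sparse_coo_tensor constructor (illustrative, not from the gist):

    import torch

    def generate_tensor_sketch(rand_h, rand_s, in_channels, out_channels):
        # one nonzero per row: entry (i, rand_h[i]) holds the random sign rand_s[i]
        indices = torch.stack([torch.arange(in_channels), rand_h])
        return torch.sparse_coo_tensor(indices, rand_s, (in_channels, out_channels)).to_dense()
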

  4. vadimkantorov revised this gist Apr 28, 2020. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions compact_bilinear_pooling.py
    @@ -2,6 +2,7 @@
    # [1] Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding, Fukui et al., https://arxiv.org/abs/1606.01847
    # [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062
    # [3] Fast and Scalable Polynomial Kernels via Explicit Feature Maps, Pham and Pagh, https://chbrown.github.io/kdd-2013-usb/kdd/p239.pdf
    + # [4] Fastfood — Approximating Kernel Expansions in Loglinear Time, Le et al., https://arxiv.org/abs/1408.3060
    # Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling

    import torch
  5. vadimkantorov revised this gist Apr 28, 2020. 1 changed file with 0 additions and 1 deletion.
    1 change: 0 additions & 1 deletion compact_bilinear_pooling.py
    @@ -2,7 +2,6 @@
    # [1] Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding, Fukui et al., https://arxiv.org/abs/1606.01847
    # [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062
    # [3] Fast and Scalable Polynomial Kernels via Explicit Feature Maps, Pham and Pagh, https://chbrown.github.io/kdd-2013-usb/kdd/p239.pdf

    # Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling

    import torch
  6. vadimkantorov revised this gist Apr 28, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion compact_bilinear_pooling.py
    @@ -19,7 +19,7 @@ def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True):
    def forward(self, x1, x2):
    fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.tensor_sketch1), signal_ndim = 1)
    fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.tensor_sketch2), signal_ndim = 1)
    - # torch.rfft does not support yet returning torch.complex64 outputs, so we do complex product manually
    + # torch.rfft does not support yet torch.complex64 outputs, so we do complex product manually
    fft_complex_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
    cbp = torch.irfft(fft_complex_product, signal_ndim = 1, signal_sizes = (self.out_channels, )) * self.out_channels
    return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2)
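
The stacked expression above is the complex product (a + bi)(c + di) = (ac - bd) + (ad + bc)i applied along the trailing (real, imag) dimension that the old torch.rfft returned. The same computation as a standalone helper (hypothetical name, equivalent to the fft_complex_product line):

    import torch

    def complex_mul(f1, f2):
        # f1, f2: (..., 2) tensors holding (real, imag) pairs
        real = f1[..., 0] * f2[..., 0] - f1[..., 1] * f2[..., 1]
        imag = f1[..., 0] * f2[..., 1] + f1[..., 1] * f2[..., 0]
        return torch.stack([real, imag], dim = -1)
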
  7. vadimkantorov revised this gist Apr 28, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion compact_bilinear_pooling.py
    @@ -9,7 +9,7 @@

    class CompactBilinearPooling(torch.nn.Module):
    def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True):
    - super(CompactBilinearPooling, self).__init__()
    + super().__init__()
    self.out_channels = out_channels
    self.sum_pool = sum_pool
    generate_tensor_sketch = lambda rand_h, rand_s, in_channels, out_channels: torch.sparse.FloatTensor(torch.stack([torch.arange(in_features), rand_h]), rand_s, [in_channels, out_channels]).to_dense()
  8. vadimkantorov revised this gist Apr 28, 2020. 1 changed file with 2 additions and 2 deletions.
    4 changes: 2 additions & 2 deletions compact_bilinear_pooling.py
    @@ -13,8 +13,8 @@ def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True):
    self.out_channels = out_channels
    self.sum_pool = sum_pool
    generate_tensor_sketch = lambda rand_h, rand_s, in_channels, out_channels: torch.sparse.FloatTensor(torch.stack([torch.arange(in_features), rand_h]), rand_s, [in_channels, out_channels]).to_dense()
    - self.tenosr_sketch1 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels1,)), 2 * torch.randint(2, size = (in_channels1,)) - 1, in_channels1, out_channels, dtype = torch.float32), requires_grad = False)
    - self.tensor_sketch2 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels2,)), 2 * torch.randint(2, size = (in_channels2,)) - 1, in_channels2, out_channels, dtype = torch.float32), requires_grad = False)
    + self.tenosr_sketch1 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels1,)), 2 * torch.randint(2, size = (in_channels1,), dtype = torch.float32) - 1, in_channels1, out_channels), requires_grad = False)
    + self.tensor_sketch2 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels2,)), 2 * torch.randint(2, size = (in_channels2,), dtype = torch.float32) - 1, in_channels2, out_channels), requires_grad = False)

    def forward(self, x1, x2):
    fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.tensor_sketch1), signal_ndim = 1)
  9. vadimkantorov revised this gist Apr 28, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion compact_bilinear_pooling.py
    @@ -3,7 +3,7 @@
    # [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062
    # [3] Fast and Scalable Polynomial Kernels via Explicit Feature Maps, Pham and Pagh, https://chbrown.github.io/kdd-2013-usb/kdd/p239.pdf

    - # Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling/blob/master/caffe-20160312/src/caffe/layers/compact_bilinear_layer.cpp
    + # Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling

    import torch

  10. vadimkantorov revised this gist Apr 28, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion compact_bilinear_pooling.py
    @@ -17,7 +17,7 @@ def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True):
    self.tensor_sketch2 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels2,)), 2 * torch.randint(2, size = (in_channels2,)) - 1, in_channels2, out_channels, dtype = torch.float32), requires_grad = False)

    def forward(self, x1, x2):
    - fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.tenosr_sketch1), signal_ndim = 1)
    + fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.tensor_sketch1), signal_ndim = 1)
    fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.tensor_sketch2), signal_ndim = 1)
    # torch.rfft does not support yet returning torch.complex64 outputs, so we do complex product manually
    fft_complex_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
  11. vadimkantorov revised this gist Apr 28, 2020. 1 changed file with 11 additions and 8 deletions.
    19 changes: 11 additions & 8 deletions compact_bilinear_pooling.py
    @@ -1,22 +1,25 @@
    # References:
    # [1] Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding, Fukui et al., https://arxiv.org/abs/1606.01847
    # [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062
    + # [3] Fast and Scalable Polynomial Kernels via Explicit Feature Maps, Pham and Pagh, https://chbrown.github.io/kdd-2013-usb/kdd/p239.pdf
    +
    + # Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling/blob/master/caffe-20160312/src/caffe/layers/compact_bilinear_layer.cpp

    import torch

    class CompactBilinearPooling(torch.nn.Module):
    - def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True):
    + def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True):
    super(CompactBilinearPooling, self).__init__()
    - self.output_dim = output_dim
    + self.out_channels = out_channels
    self.sum_pool = sum_pool
    - generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), [input_dim, output_dim]).to_dense()
    - self.sketch1 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim), requires_grad = False)
    - self.sketch2 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim), requires_grad = False)
    + generate_tensor_sketch = lambda rand_h, rand_s, in_channels, out_channels: torch.sparse.FloatTensor(torch.stack([torch.arange(in_features), rand_h]), rand_s, [in_channels, out_channels]).to_dense()
    + self.tenosr_sketch1 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels1,)), 2 * torch.randint(2, size = (in_channels1,)) - 1, in_channels1, out_channels, dtype = torch.float32), requires_grad = False)
    + self.tensor_sketch2 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels2,)), 2 * torch.randint(2, size = (in_channels2,)) - 1, in_channels2, out_channels, dtype = torch.float32), requires_grad = False)

    def forward(self, x1, x2):
    - fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch1), signal_ndim = 1)
    - fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch2), signal_ndim = 1)
    + fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.tenosr_sketch1), signal_ndim = 1)
    + fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.tensor_sketch2), signal_ndim = 1)
    # torch.rfft does not support yet returning torch.complex64 outputs, so we do complex product manually
    fft_complex_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
    - cbp = torch.irfft(fft_complex_product, signal_ndim = 1, signal_sizes = (self.output_dim, )) * self.output_dim
    + cbp = torch.irfft(fft_complex_product, signal_ndim = 1, signal_sizes = (self.out_channels, )) * self.out_channels
    return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2)
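
This revision introduces the in_channels1/in_channels2/out_channels interface that the later revisions keep. A hypothetical usage sketch (shapes assumed; requires a PyTorch version that still ships torch.rfft, i.e. older than 1.8; note that in the revisions after this one, __init__ still assigns self.tenosr_sketch1 while forward reads self.tensor_sketch1, so the attribute names must be reconciled for this to run):

    import torch

    x1 = torch.randn(4, 512, 14, 14)   # NCHW features from one branch
    x2 = torch.randn(4, 256, 14, 14)   # same batch and spatial size, different channels
    cbp = CompactBilinearPooling(in_channels1 = 512, in_channels2 = 256, out_channels = 8000)
    y = cbp(x1, x2)                    # (4, 8000) with sum_pool = True
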
  12. vadimkantorov revised this gist Apr 28, 2020. 1 changed file with 7 additions and 2 deletions.
    9 changes: 7 additions & 2 deletions compact_bilinear_pooling.py
    @@ -1,3 +1,7 @@
    + # References:
    + # [1] Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding, Fukui et al., https://arxiv.org/abs/1606.01847
    + # [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062
    +
    import torch

    class CompactBilinearPooling(torch.nn.Module):
    @@ -12,6 +16,7 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True):
    def forward(self, x1, x2):
    fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch1), signal_ndim = 1)
    fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch2), signal_ndim = 1)
    - fft_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
    - cbp = torch.irfft(fft_product, signal_ndim = 1, signal_sizes = (self.output_dim, )) * self.output_dim
    + # torch.rfft does not support yet returning torch.complex64 outputs, so we do complex product manually
    + fft_complex_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
    + cbp = torch.irfft(fft_complex_product, signal_ndim = 1, signal_sizes = (self.output_dim, )) * self.output_dim
    return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2)
  13. vadimkantorov revised this gist Mar 15, 2019. 1 changed file with 4 additions and 4 deletions.
    8 changes: 4 additions & 4 deletions compact_bilinear_pooling.py
    @@ -6,12 +6,12 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True):
    self.output_dim = output_dim
    self.sum_pool = sum_pool
    generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), [input_dim, output_dim]).to_dense()
    - self.sketch_matrix1 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim), requires_grad = False)
    - self.sketch_matrix2 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim), requires_grad = False)
    + self.sketch1 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim), requires_grad = False)
    + self.sketch2 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim), requires_grad = False)

    def forward(self, x1, x2):
    - fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch_matrix1), signal_ndim = 1)
    - fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch_matrix2), signal_ndim = 1)
    + fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch1), signal_ndim = 1)
    + fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch2), signal_ndim = 1)
    fft_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
    cbp = torch.irfft(fft_product, signal_ndim = 1, signal_sizes = (self.output_dim, )) * self.output_dim
    return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2)
  14. vadimkantorov revised this gist Mar 15, 2019. 1 changed file with 6 additions and 6 deletions.
    12 changes: 6 additions & 6 deletions compact_bilinear_pooling.py
    @@ -6,12 +6,12 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True):
    self.output_dim = output_dim
    self.sum_pool = sum_pool
    generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), [input_dim, output_dim]).to_dense()
    - self.sketch_matrix1 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim))
    - self.sketch_matrix2 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim))
    + self.sketch_matrix1 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim), requires_grad = False)
    + self.sketch_matrix2 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim), requires_grad = False)

    def forward(self, x1, x2):
    - fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch_matrix1), 1)
    - fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch_matrix2), 1)
    + fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch_matrix1), signal_ndim = 1)
    + fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch_matrix2), signal_ndim = 1)
    fft_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
    - cbp = torch.irfft(fft_product, 1, signal_sizes = (self.output_dim,)) * self.output_dim
    - return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2)
    + cbp = torch.irfft(fft_product, signal_ndim = 1, signal_sizes = (self.output_dim, )) * self.output_dim
    + return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2)
  15. vadimkantorov revised this gist Apr 12, 2018. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion compact_bilinear_pooling.py
    @@ -13,5 +13,5 @@ def forward(self, x1, x2):
    fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch_matrix1), 1)
    fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch_matrix2), 1)
    fft_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
    - cbp = torch.irfft(fft_product, 1, signal_sizes = (self.output_dim,)).view(len(x1), x1.size(-2), x1.size(-1), -1) * self.output_dim
    + cbp = torch.irfft(fft_product, 1, signal_sizes = (self.output_dim,)) * self.output_dim
    return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2)
  16. vadimkantorov revised this gist Apr 12, 2018. 1 changed file with 4 additions and 4 deletions.
    8 changes: 4 additions & 4 deletions compact_bilinear_pooling.py
    @@ -6,12 +6,12 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True):
    self.output_dim = output_dim
    self.sum_pool = sum_pool
    generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), [input_dim, output_dim]).to_dense()
    - self.sketch_matrix1 = nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim))
    - self.sketch_matrix2 = nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim))
    + self.sketch_matrix1 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim))
    + self.sketch_matrix2 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim))

    def forward(self, x1, x2):
    - fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch_matrix1).view(-1, self.output_dim), 1)
    - fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch_matrix2).view(-1, self.output_dim), 1)
    + fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch_matrix1), 1)
    + fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch_matrix2), 1)
    fft_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
    cbp = torch.irfft(fft_product, 1, signal_sizes = (self.output_dim,)).view(len(x1), x1.size(-2), x1.size(-1), -1) * self.output_dim
    return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2)
  17. vadimkantorov revised this gist Apr 11, 2018. 1 changed file with 8 additions and 12 deletions.
    20 changes: 8 additions & 12 deletions compact_bilinear_pooling.py
    @@ -1,21 +1,17 @@
    import torch
    - import torch.nn as nn

    - class CompactBilinearPooling(nn.Module):
    + class CompactBilinearPooling(torch.nn.Module):
    def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True):
    super(CompactBilinearPooling, self).__init__()
    self.output_dim = output_dim
    self.sum_pool = sum_pool

    - generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), torch.Size([input_dim, output_dim])).to_dense()
    + generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), [input_dim, output_dim]).to_dense()
    self.sketch_matrix1 = nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim))
    self.sketch_matrix2 = nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim))

    - def forward(self, bottom1, bottom2):
    - sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().matmul(self.sketch_matrix1).view(-1, self.output_dim)
    - sketch_2 = bottom2.permute(0, 2, 3, 1).contiguous().matmul(self.sketch_matrix2).view(-1, self.output_dim)
    - fft1_real, fft1_imag = torch.rfft(sketch_1, 1).permute(2, 0, 1)
    - fft2_real, fft2_imag = torch.rfft(sketch_2, 1).permute(2, 0, 1)
    - fft_product = torch.stack([fft1_real * fft2_real - fft1_imag * fft2_imag, fft1_real * fft2_imag - fft1_imag * fft2_real], dim = -1)
    - cbp = torch.irfft(fft_product, 1, signal_sizes = (self.output_dim,)).view(len(bottom1), bottom1.size(-2), bottom1.size(-1), self.output_dim) * self.output_dim
    - return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2)
    + def forward(self, x1, x2):
    + fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch_matrix1).view(-1, self.output_dim), 1)
    + fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch_matrix2).view(-1, self.output_dim), 1)
    + fft_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1)
    + cbp = torch.irfft(fft_product, 1, signal_sizes = (self.output_dim,)).view(len(x1), x1.size(-2), x1.size(-1), -1) * self.output_dim
    + return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2)
  18. vadimkantorov revised this gist Apr 11, 2018. 1 changed file with 15 additions and 24 deletions.
    39 changes: 15 additions & 24 deletions compact_bilinear_pooling.py
    @@ -2,29 +2,20 @@
    import torch.nn as nn

    class CompactBilinearPooling(nn.Module):
    - def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1 = None, rand_s_1 = None, rand_h_2 = None, rand_s_2 = None):
    - super(CompactBilinearPooling, self).__init__()
    - self.output_dim = output_dim
    - self.sum_pool = sum_pool
    + def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True):
    + super(CompactBilinearPooling, self).__init__()
    + self.output_dim = output_dim
    + self.sum_pool = sum_pool

    - if rand_h_1 is None:
    - rand_h_1 = torch.randint(output_dim, size = (input_dim1,))
    - if rand_s_1 is None:
    - rand_s_1 = 2 * torch.randint(2, size = (input_dim1,)) - 1
    - if rand_h_2 is None:
    - rand_h_2 = torch.randint(output_dim, size = (input_dim2,))
    - if rand_s_2 is None:
    - rand_s_2 = 2 * torch.randint(2, size = (input_dim2,)) - 1

    - generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim), rand_h]], dim = -1).t(), rand_s, torch.Size([input_dim, output_dim])).to_dense()
    - self.sparse_sketch_matrix1 = generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, output_dim)
    - self.sparse_sketch_matrix2 = generate_sketch_matrix(rand_h_2, rand_s_2, input_dim2, output_dim)
    + generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), torch.Size([input_dim, output_dim])).to_dense()
    + self.sketch_matrix1 = nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim))
    + self.sketch_matrix2 = nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim))

    - def forward(self, bottom1, bottom2):
    - sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().mm(self.sparse_sketch_matrix1).view(-1, output_dim)
    - sketch_2 = bottom2.permute(0, 2, 3, 1).contiguous().mm(self.sparse_sketch_matrix2).view(-1, output_dim)
    - fft1_real, fft1_imag = torch.rfft(sketch_1, 1).permute(2, 0, 1)
    - fft2_real, fft2_imag = torch.rfft(sketch_2, 1).permute(2, 0, 1)
    - fft_product = torch.stack([fft1_real * fft2_real - fft1_imag * fft2_imag, fft1_real * fft2_imag - fft1_imag * fft2_real], dim = -1)
    - cbp = torch.irfft(fft_product).view(len(bottom1), bottom1.size(-2), bottom1.size(-1), self.output_dim) * self.output_dim
    - return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2)
    + def forward(self, bottom1, bottom2):
    + sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().matmul(self.sketch_matrix1).view(-1, self.output_dim)
    + sketch_2 = bottom2.permute(0, 2, 3, 1).contiguous().matmul(self.sketch_matrix2).view(-1, self.output_dim)
    + fft1_real, fft1_imag = torch.rfft(sketch_1, 1).permute(2, 0, 1)
    + fft2_real, fft2_imag = torch.rfft(sketch_2, 1).permute(2, 0, 1)
    + fft_product = torch.stack([fft1_real * fft2_real - fft1_imag * fft2_imag, fft1_real * fft2_imag - fft1_imag * fft2_real], dim = -1)
    + cbp = torch.irfft(fft_product, 1, signal_sizes = (self.output_dim,)).view(len(bottom1), bottom1.size(-2), bottom1.size(-1), self.output_dim) * self.output_dim
    + return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2)
  19. vadimkantorov revised this gist Apr 11, 2018. 1 changed file with 4 additions and 4 deletions.
    8 changes: 4 additions & 4 deletions compact_bilinear_pooling.py
    @@ -8,13 +8,13 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1
    self.sum_pool = sum_pool

    if rand_h_1 is None:
    - rand_h_1 = torch.randint(output_dim, sizes = (input_dim1,))
    + rand_h_1 = torch.randint(output_dim, size = (input_dim1,))
    if rand_s_1 is None:
    - rand_s_1 = 2 * torch.randint(2, sizes = (input_dim1,)) - 1
    + rand_s_1 = 2 * torch.randint(2, size = (input_dim1,)) - 1
    if rand_h_2 is None:
    - rand_h_2 = torch.randint(output_dim, sizes = (input_dim2,))
    + rand_h_2 = torch.randint(output_dim, size = (input_dim2,))
    if rand_s_2 is None:
    - rand_s_2 = 2 * torch.randint(2, sizes = (input_dim2,)) - 1
    + rand_s_2 = 2 * torch.randint(2, size = (input_dim2,)) - 1

    generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim), rand_h]], dim = -1).t(), rand_s, torch.Size([input_dim, output_dim])).to_dense()
    self.sparse_sketch_matrix1 = generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, output_dim)
  20. vadimkantorov revised this gist Apr 11, 2018. 1 changed file with 6 additions and 8 deletions.
    14 changes: 6 additions & 8 deletions compact_bilinear_pooling.py
    @@ -4,23 +4,21 @@
    class CompactBilinearPooling(nn.Module):
    def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1 = None, rand_s_1 = None, rand_h_2 = None, rand_s_2 = None):
    super(CompactBilinearPooling, self).__init__()
    - self.input_dim1 = input_dim1
    - self.input_dim2 = input_dim2
    self.output_dim = output_dim
    self.sum_pool = sum_pool

    if rand_h_1 is None:
    - rand_h_1 = torch.randint(output_dim, sizes = (self.input_dim1,))
    + rand_h_1 = torch.randint(output_dim, sizes = (input_dim1,))
    if rand_s_1 is None:
    - rand_s_1 = 2 * torch.randint(2, sizes = (self.input_dim1,)) - 1
    + rand_s_1 = 2 * torch.randint(2, sizes = (input_dim1,)) - 1
    if rand_h_2 is None:
    - rand_h_2 = torch.randint(output_dim, sizes = (self.input_dim2,))
    + rand_h_2 = torch.randint(output_dim, sizes = (input_dim2,))
    if rand_s_2 is None:
    - rand_s_2 = 2 * torch.randint(2, sizes = (self.input_dim2,)) - 1
    + rand_s_2 = 2 * torch.randint(2, sizes = (input_dim2,)) - 1

    generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim), rand_h]], dim = -1).t(), rand_s, torch.Size([input_dim, output_dim])).to_dense()
    - self.sparse_sketch_matrix1 = generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, self.output_dim)
    - self.sparse_sketch_matrix2 = generate_sketch_matrix(rand_h_2, rand_s_2, input_dim2, self.output_dim)
    + self.sparse_sketch_matrix1 = generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, output_dim)
    + self.sparse_sketch_matrix2 = generate_sketch_matrix(rand_h_2, rand_s_2, input_dim2, output_dim)

    def forward(self, bottom1, bottom2):
    sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().mm(self.sparse_sketch_matrix1).view(-1, output_dim)
  21. vadimkantorov revised this gist Apr 11, 2018. 1 changed file with 4 additions and 8 deletions.
    12 changes: 4 additions & 8 deletions compact_bilinear_pooling.py
    @@ -18,8 +18,9 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1
    if rand_s_2 is None:
    rand_s_2 = 2 * torch.randint(2, sizes = (self.input_dim2,)) - 1

    - self.sparse_sketch_matrix1 = self.generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, self.output_dim)
    - self.sparse_sketch_matrix2 = self.generate_sketch_matrix(rand_h_2, rand_s_2, input_dim2, self.output_dim)
    + generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim), rand_h]], dim = -1).t(), rand_s, torch.Size([input_dim, output_dim])).to_dense()
    + self.sparse_sketch_matrix1 = generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, self.output_dim)
    + self.sparse_sketch_matrix2 = generate_sketch_matrix(rand_h_2, rand_s_2, input_dim2, self.output_dim)

    def forward(self, bottom1, bottom2):
    sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().mm(self.sparse_sketch_matrix1).view(-1, output_dim)
    @@ -28,9 +29,4 @@ def forward(self, bottom1, bottom2):
    fft2_real, fft2_imag = torch.rfft(sketch_2, 1).permute(2, 0, 1)
    fft_product = torch.stack([fft1_real * fft2_real - fft1_imag * fft2_imag, fft1_real * fft2_imag - fft1_imag * fft2_real], dim = -1)
    cbp = torch.irfft(fft_product).view(len(bottom1), bottom1.size(-2), bottom1.size(-1), self.output_dim) * self.output_dim
    - return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2)

    - @staticmethod
    - def generate_sketch_matrix(rand_h, rand_s, input_dim, output_dim):
    - indices = torch.stack([torch.arange(input_dim), rand_h]], dim = -1)
    - return torch.sparse.FloatTensor(indices.t(), rand_s, torch.Size([input_dim, output_dim])).to_dense()
    + return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2)
  22. vadimkantorov revised this gist Apr 11, 2018. 1 changed file with 5 additions and 5 deletions.
    10 changes: 5 additions & 5 deletions compact_bilinear_pooling.py
    @@ -10,13 +10,13 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1
    self.sum_pool = sum_pool

    if rand_h_1 is None:
    - rand_h_1 = torch.randint(output_dim, size = (self.input_dim1,))
    + rand_h_1 = torch.randint(output_dim, sizes = (self.input_dim1,))
    if rand_s_1 is None:
    - rand_s_1 = 2 * torch.randint(2, size = (self.input_dim1,)) - 1
    + rand_s_1 = 2 * torch.randint(2, sizes = (self.input_dim1,)) - 1
    if rand_h_2 is None:
    - rand_h_2 = torch.randint(output_dim, size = (self.input_dim2,))
    + rand_h_2 = torch.randint(output_dim, sizes = (self.input_dim2,))
    if rand_s_2 is None:
    - rand_s_2 = 2 * torch.randint(2, size = (self.input_dim2, )) - 1
    + rand_s_2 = 2 * torch.randint(2, sizes = (self.input_dim2,)) - 1

    self.sparse_sketch_matrix1 = self.generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, self.output_dim)
    self.sparse_sketch_matrix2 = self.generate_sketch_matrix(rand_h_2, rand_s_2, input_dim2, self.output_dim)
    @@ -32,5 +32,5 @@ def forward(self, bottom1, bottom2):

    @staticmethod
    def generate_sketch_matrix(rand_h, rand_s, input_dim, output_dim):
    - indices = np.concatenate((torch.arange(input_dim)[..., np.newaxis], rand_h[..., np.newaxis]), axis=1)
    + indices = torch.stack([torch.arange(input_dim), rand_h]], dim = -1)
    return torch.sparse.FloatTensor(indices.t(), rand_s, torch.Size([input_dim, output_dim])).to_dense()
  23. vadimkantorov revised this gist Apr 11, 2018. 1 changed file with 12 additions and 13 deletions.
    25 changes: 12 additions & 13 deletions compact_bilinear_pooling.py
    @@ -18,20 +18,19 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1
    if rand_s_2 is None:
    rand_s_2 = 2 * torch.randint(2, size = (self.input_dim2, )) - 1

    - self.sparse_sketch_matrix1 = self.generate_sketch_matrix(rand_h_1, rand_s_1, self.output_dim)
    - self.sparse_sketch_matrix2 = self.generate_sketch_matrix(rand_h_2, rand_s_2, self.output_dim)
    + self.sparse_sketch_matrix1 = self.generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, self.output_dim)
    + self.sparse_sketch_matrix2 = self.generate_sketch_matrix(rand_h_2, rand_s_2, input_dim2, self.output_dim)

    def forward(self, bottom1, bottom2):
    - assert bottom1.size(1) == self.input_dim1 and bottom2.size(1) == self.input_dim2
    - batch_size, _, height, width = bottom1.size()

    - sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1).mm(self.sparse_sketch_matrix1)
    - sketch_2 = bottom2.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim2).mm(self.sparse_sketch_matrix2)

    + sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().mm(self.sparse_sketch_matrix1).view(-1, output_dim)
    + sketch_2 = bottom2.permute(0, 2, 3, 1).contiguous().mm(self.sparse_sketch_matrix2).view(-1, output_dim)
    fft1_real, fft1_imag = torch.rfft(sketch_1, 1).permute(2, 0, 1)
    fft2_real, fft2_imag = torch.rfft(sketch_2, 1).permute(2, 0, 1)
    - fft_product_real = fft1_real * fft2_real - fft1_imag * fft2_imag
    - fft_product_imag = fft1_real * fft2_imag - fft1_imag * fft2_real

    - cbp = torch.irfft(torch.stack([fft_product_real, fft_product_imag], dim = -1).view(batch_size, height, width, self.output_dim) * self.output_dim
    - return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2)
    + fft_product = torch.stack([fft1_real * fft2_real - fft1_imag * fft2_imag, fft1_real * fft2_imag - fft1_imag * fft2_real], dim = -1)
    + cbp = torch.irfft(fft_product).view(len(bottom1), bottom1.size(-2), bottom1.size(-1), self.output_dim) * self.output_dim
    + return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2)

    @staticmethod
    def generate_sketch_matrix(rand_h, rand_s, input_dim, output_dim):
    indices = np.concatenate((torch.arange(input_dim)[..., np.newaxis], rand_h[..., np.newaxis]), axis=1)
    return torch.sparse.FloatTensor(indices.t(), rand_s, torch.Size([input_dim, output_dim])).to_dense()
  24. vadimkantorov revised this gist Apr 11, 2018. 1 changed file with 19 additions and 0 deletions.
    19 changes: 19 additions & 0 deletions compact_bilinear_pooling.py
    @@ -2,6 +2,25 @@
    import torch.nn as nn

    class CompactBilinearPooling(nn.Module):
    + def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1 = None, rand_s_1 = None, rand_h_2 = None, rand_s_2 = None):
    + super(CompactBilinearPooling, self).__init__()
    + self.input_dim1 = input_dim1
    + self.input_dim2 = input_dim2
    + self.output_dim = output_dim
    + self.sum_pool = sum_pool
    +
    + if rand_h_1 is None:
    + rand_h_1 = torch.randint(output_dim, size = (self.input_dim1,))
    + if rand_s_1 is None:
    + rand_s_1 = 2 * torch.randint(2, size = (self.input_dim1,)) - 1
    + if rand_h_2 is None:
    + rand_h_2 = torch.randint(output_dim, size = (self.input_dim2,))
    + if rand_s_2 is None:
    + rand_s_2 = 2 * torch.randint(2, size = (self.input_dim2, )) - 1
    +
    + self.sparse_sketch_matrix1 = self.generate_sketch_matrix(rand_h_1, rand_s_1, self.output_dim)
    + self.sparse_sketch_matrix2 = self.generate_sketch_matrix(rand_h_2, rand_s_2, self.output_dim)

    def forward(self, bottom1, bottom2):
    assert bottom1.size(1) == self.input_dim1 and bottom2.size(1) == self.input_dim2
    batch_size, _, height, width = bottom1.size()
  25. vadimkantorov revised this gist Apr 11, 2018. 1 changed file with 4 additions and 13 deletions.
    17 changes: 4 additions & 13 deletions compact_bilinear_pooling.py
    @@ -6,22 +6,13 @@ def forward(self, bottom1, bottom2):
    assert bottom1.size(1) == self.input_dim1 and bottom2.size(1) == self.input_dim2
    batch_size, _, height, width = bottom1.size()

    - bottom1_flat = bottom1.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1)
    - bottom2_flat = bottom2.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim2)

    - sketch_1 = bottom1_flat.mm(self.sparse_sketch_matrix1)
    - sketch_2 = bottom2_flat.mm(self.sparse_sketch_matrix2)
    + sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1).mm(self.sparse_sketch_matrix1)
    + sketch_2 = bottom2.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim2).mm(self.sparse_sketch_matrix2)

    fft1_real, fft1_imag = torch.rfft(sketch_1, 1).permute(2, 0, 1)
    fft2_real, fft2_imag = torch.rfft(sketch_2, 1).permute(2, 0, 1)

    fft_product_real = fft1_real * fft2_real - fft1_imag * fft2_imag
    fft_product_imag = fft1_real * fft2_imag - fft1_imag * fft2_real

    - cbp_flat = torch.irfft(torch.stack([fft_product_real, fft_product_imag], dim = -1)
    - cbp = cbp_flat.view(batch_size, height, width, self.output_dim) * self.output_dim

    - if self.sum_pool:
    - return cbp.sum(dim = 1).sum(dim = 1)

    - return cbp.permute(0, 3, 1, 2)
    + cbp = torch.irfft(torch.stack([fft_product_real, fft_product_imag], dim = -1).view(batch_size, height, width, self.output_dim) * self.output_dim
    + return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2)
  26. vadimkantorov revised this gist Apr 11, 2018. 1 changed file with 23 additions and 1 deletion.
    24 changes: 23 additions & 1 deletion compact_bilinear_pooling.py
    @@ -2,4 +2,26 @@
    import torch.nn as nn

    class CompactBilinearPooling(nn.Module):
    - pass
    + def forward(self, bottom1, bottom2):
    + assert bottom1.size(1) == self.input_dim1 and bottom2.size(1) == self.input_dim2
    + batch_size, _, height, width = bottom1.size()
    +
    + bottom1_flat = bottom1.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1)
    + bottom2_flat = bottom2.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim2)
    +
    + sketch_1 = bottom1_flat.mm(self.sparse_sketch_matrix1)
    + sketch_2 = bottom2_flat.mm(self.sparse_sketch_matrix2)
    +
    + fft1_real, fft1_imag = torch.rfft(sketch_1, 1).permute(2, 0, 1)
    + fft2_real, fft2_imag = torch.rfft(sketch_2, 1).permute(2, 0, 1)
    +
    + fft_product_real = fft1_real * fft2_real - fft1_imag * fft2_imag
    + fft_product_imag = fft1_real * fft2_imag - fft1_imag * fft2_real
    +
    + cbp_flat = torch.irfft(torch.stack([fft_product_real, fft_product_imag], dim = -1)
    + cbp = cbp_flat.view(batch_size, height, width, self.output_dim) * self.output_dim
    +
    + if self.sum_pool:
    + return cbp.sum(dim = 1).sum(dim = 1)
    +
    + return cbp.permute(0, 3, 1, 2)
  27. vadimkantorov created this gist Apr 11, 2018.
    5 changes: 5 additions & 0 deletions compact_bilinear_pooling.py
    @@ -0,0 +1,5 @@
    + import torch
    + import torch.nn as nn
    +
    + class CompactBilinearPooling(nn.Module):
    + pass