Last active
September 22, 2021 07:51
-
-
Save vadimkantorov/d9b56f9b85f1f4ce59ffecf893a1581a to your computer and use it in GitHub Desktop.
Revisions
-
vadimkantorov revised this gist
Sep 22, 2021 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -3,7 +3,7 @@ # [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062 # [3] Fast and Scalable Polynomial Kernels via Explicit Feature Maps, Pham and Pagh, https://chbrown.github.io/kdd-2013-usb/kdd/p239.pdf # [4] Fastfood — Approximating Kernel Expansions in Loglinear Time, Le et al., https://arxiv.org/abs/1408.3060 # [5] Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling # TODO: migrate to use of new native complex64 types # TODO: change strided x coo matmul to torch.matmul(): M[sparse_coo] @ M[strided] -> M[strided] -
vadimkantorov revised this gist
Sep 22, 2021 . 1 changed file with 3 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -5,6 +5,9 @@ # [4] Fastfood — Approximating Kernel Expansions in Loglinear Time, Le et al., https://arxiv.org/abs/1408.3060 # Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling # TODO: migrate to use of new native complex64 types # TODO: change strided x coo matmul to torch.matmul(): M[sparse_coo] @ M[strided] -> M[strided] import torch class CompactBilinearPooling(torch.nn.Module): -
vadimkantorov revised this gist
Mar 19, 2021 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -12,7 +12,7 @@ def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True): super().__init__() self.out_channels = out_channels self.sum_pool = sum_pool generate_tensor_sketch = lambda rand_h, rand_s, in_channels, out_channels: torch.sparse.FloatTensor(torch.stack([torch.arange(in_channels), rand_h]), rand_s, [in_channels, out_channels]).to_dense() self.tenosr_sketch1 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels1,)), 2 * torch.randint(2, size = (in_channels1,), dtype = torch.float32) - 1, in_channels1, out_channels), requires_grad = False) self.tensor_sketch2 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels2,)), 2 * torch.randint(2, size = (in_channels2,), dtype = torch.float32) - 1, in_channels2, out_channels), requires_grad = False) -
vadimkantorov revised this gist
Apr 28, 2020 . 1 changed file with 1 addition and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -2,6 +2,7 @@ # [1] Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding, Fukui et al., https://arxiv.org/abs/1606.01847 # [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062 # [3] Fast and Scalable Polynomial Kernels via Explicit Feature Maps, Pham and Pagh, https://chbrown.github.io/kdd-2013-usb/kdd/p239.pdf # [4] Fastfood — Approximating Kernel Expansions in Loglinear Time, Le et al., https://arxiv.org/abs/1408.3060 # Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling import torch -
vadimkantorov revised this gist
Apr 28, 2020 . 1 changed file with 0 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -2,7 +2,6 @@ # [1] Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding, Fukui et al., https://arxiv.org/abs/1606.01847 # [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062 # [3] Fast and Scalable Polynomial Kernels via Explicit Feature Maps, Pham and Pagh, https://chbrown.github.io/kdd-2013-usb/kdd/p239.pdf # Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling import torch -
vadimkantorov revised this gist
Apr 28, 2020 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -19,7 +19,7 @@ def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True): def forward(self, x1, x2): fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.tensor_sketch1), signal_ndim = 1) fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.tensor_sketch2), signal_ndim = 1) # torch.rfft does not support yet torch.complex64 outputs, so we do complex product manually fft_complex_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1) cbp = torch.irfft(fft_complex_product, signal_ndim = 1, signal_sizes = (self.out_channels, )) * self.out_channels return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2) -
vadimkantorov revised this gist
Apr 28, 2020 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -9,7 +9,7 @@ class CompactBilinearPooling(torch.nn.Module): def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True): super().__init__() self.out_channels = out_channels self.sum_pool = sum_pool generate_tensor_sketch = lambda rand_h, rand_s, in_channels, out_channels: torch.sparse.FloatTensor(torch.stack([torch.arange(in_features), rand_h]), rand_s, [in_channels, out_channels]).to_dense() -
vadimkantorov revised this gist
Apr 28, 2020 . 1 changed file with 2 additions and 2 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -13,8 +13,8 @@ def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True): self.out_channels = out_channels self.sum_pool = sum_pool generate_tensor_sketch = lambda rand_h, rand_s, in_channels, out_channels: torch.sparse.FloatTensor(torch.stack([torch.arange(in_features), rand_h]), rand_s, [in_channels, out_channels]).to_dense() self.tenosr_sketch1 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels1,)), 2 * torch.randint(2, size = (in_channels1,), dtype = torch.float32) - 1, in_channels1, out_channels), requires_grad = False) self.tensor_sketch2 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels2,)), 2 * torch.randint(2, size = (in_channels2,), dtype = torch.float32) - 1, in_channels2, out_channels), requires_grad = False) def forward(self, x1, x2): fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.tensor_sketch1), signal_ndim = 1) -
vadimkantorov revised this gist
Apr 28, 2020 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -3,7 +3,7 @@ # [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062 # [3] Fast and Scalable Polynomial Kernels via Explicit Feature Maps, Pham and Pagh, https://chbrown.github.io/kdd-2013-usb/kdd/p239.pdf # Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling import torch -
vadimkantorov revised this gist
Apr 28, 2020 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -17,7 +17,7 @@ def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True): self.tensor_sketch2 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels2,)), 2 * torch.randint(2, size = (in_channels2,)) - 1, in_channels2, out_channels, dtype = torch.float32), requires_grad = False) def forward(self, x1, x2): fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.tensor_sketch1), signal_ndim = 1) fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.tensor_sketch2), signal_ndim = 1) # torch.rfft does not support yet returning torch.complex64 outputs, so we do complex product manually fft_complex_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1) -
vadimkantorov revised this gist
Apr 28, 2020 . 1 changed file with 11 additions and 8 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,22 +1,25 @@ # References: # [1] Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding, Fukui et al., https://arxiv.org/abs/1606.01847 # [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062 # [3] Fast and Scalable Polynomial Kernels via Explicit Feature Maps, Pham and Pagh, https://chbrown.github.io/kdd-2013-usb/kdd/p239.pdf # Original implementation in Caffe: https://github.com/gy20073/compact_bilinear_pooling/blob/master/caffe-20160312/src/caffe/layers/compact_bilinear_layer.cpp import torch class CompactBilinearPooling(torch.nn.Module): def __init__(self, in_channels1, in_channels2, out_channels, sum_pool = True): super(CompactBilinearPooling, self).__init__() self.out_channels = out_channels self.sum_pool = sum_pool generate_tensor_sketch = lambda rand_h, rand_s, in_channels, out_channels: torch.sparse.FloatTensor(torch.stack([torch.arange(in_features), rand_h]), rand_s, [in_channels, out_channels]).to_dense() self.tenosr_sketch1 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels1,)), 2 * torch.randint(2, size = (in_channels1,)) - 1, in_channels1, out_channels, dtype = torch.float32), requires_grad = False) self.tensor_sketch2 = torch.nn.Parameter(generate_tensor_sketch(torch.randint(out_channels, size = (in_channels2,)), 2 * torch.randint(2, size = (in_channels2,)) - 1, in_channels2, out_channels, dtype = torch.float32), requires_grad = False) def forward(self, x1, x2): fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.tenosr_sketch1), signal_ndim = 1) fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.tensor_sketch2), signal_ndim = 1) # torch.rfft does not support yet returning torch.complex64 outputs, so we do complex product manually fft_complex_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1) cbp = torch.irfft(fft_complex_product, signal_ndim = 1, signal_sizes = (self.out_channels, )) * self.out_channels return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2) -
vadimkantorov revised this gist
Apr 28, 2020 . 1 changed file with 7 additions and 2 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,3 +1,7 @@ # References: # [1] Multimodal Compact Bilinear Pooling for Visual Question Answering and Visual Grounding, Fukui et al., https://arxiv.org/abs/1606.01847 # [2] Compact Bilinear Pooling, Gao et al., https://arxiv.org/abs/1511.06062 import torch class CompactBilinearPooling(torch.nn.Module): @@ -12,6 +16,7 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True): def forward(self, x1, x2): fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch1), signal_ndim = 1) fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch2), signal_ndim = 1) # torch.rfft does not support yet returning torch.complex64 outputs, so we do complex product manually fft_complex_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1) cbp = torch.irfft(fft_complex_product, signal_ndim = 1, signal_sizes = (self.output_dim, )) * self.output_dim return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2) -
vadimkantorov revised this gist
Mar 15, 2019 . 1 changed file with 4 additions and 4 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -6,12 +6,12 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True): self.output_dim = output_dim self.sum_pool = sum_pool generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), [input_dim, output_dim]).to_dense() self.sketch1 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim), requires_grad = False) self.sketch2 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim), requires_grad = False) def forward(self, x1, x2): fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch1), signal_ndim = 1) fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch2), signal_ndim = 1) fft_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1) cbp = torch.irfft(fft_product, signal_ndim = 1, signal_sizes = (self.output_dim, )) * self.output_dim return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2) -
vadimkantorov revised this gist
Mar 15, 2019 . 1 changed file with 6 additions and 6 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -6,12 +6,12 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True): self.output_dim = output_dim self.sum_pool = sum_pool generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), [input_dim, output_dim]).to_dense() self.sketch_matrix1 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim), requires_grad = False) self.sketch_matrix2 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim), requires_grad = False) def forward(self, x1, x2): fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch_matrix1), signal_ndim = 1) fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch_matrix2), signal_ndim = 1) fft_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1) cbp = torch.irfft(fft_product, signal_ndim = 1, signal_sizes = (self.output_dim, )) * self.output_dim return cbp.sum(dim = [1, 2]) if self.sum_pool else cbp.permute(0, 3, 1, 2) -
vadimkantorov revised this gist
Apr 12, 2018 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -13,5 +13,5 @@ def forward(self, x1, x2): fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch_matrix1), 1) fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch_matrix2), 1) fft_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1) cbp = torch.irfft(fft_product, 1, signal_sizes = (self.output_dim,)) * self.output_dim return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2) -
vadimkantorov revised this gist
Apr 12, 2018 . 1 changed file with 4 additions and 4 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -6,12 +6,12 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True): self.output_dim = output_dim self.sum_pool = sum_pool generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), [input_dim, output_dim]).to_dense() self.sketch_matrix1 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim)) self.sketch_matrix2 = torch.nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim)) def forward(self, x1, x2): fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch_matrix1), 1) fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch_matrix2), 1) fft_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1) cbp = torch.irfft(fft_product, 1, signal_sizes = (self.output_dim,)).view(len(x1), x1.size(-2), x1.size(-1), -1) * self.output_dim return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2) -
vadimkantorov revised this gist
Apr 11, 2018 . 1 changed file with 8 additions and 12 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,21 +1,17 @@ import torch class CompactBilinearPooling(torch.nn.Module): def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True): super(CompactBilinearPooling, self).__init__() self.output_dim = output_dim self.sum_pool = sum_pool generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), [input_dim, output_dim]).to_dense() self.sketch_matrix1 = nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim)) self.sketch_matrix2 = nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim)) def forward(self, x1, x2): fft1 = torch.rfft(x1.permute(0, 2, 3, 1).matmul(self.sketch_matrix1).view(-1, self.output_dim), 1) fft2 = torch.rfft(x2.permute(0, 2, 3, 1).matmul(self.sketch_matrix2).view(-1, self.output_dim), 1) fft_product = torch.stack([fft1[..., 0] * fft2[..., 0] - fft1[..., 1] * fft2[..., 1], fft1[..., 0] * fft2[..., 1] + fft1[..., 1] * fft2[..., 0]], dim = -1) cbp = torch.irfft(fft_product, 1, signal_sizes = (self.output_dim,)).view(len(x1), x1.size(-2), x1.size(-1), -1) * self.output_dim return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2) -
vadimkantorov revised this gist
Apr 11, 2018 . 1 changed file with 15 additions and 24 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -2,29 +2,20 @@ import torch.nn as nn class CompactBilinearPooling(nn.Module): def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True): super(CompactBilinearPooling, self).__init__() self.output_dim = output_dim self.sum_pool = sum_pool generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim, out = torch.LongTensor()), rand_h.long()]), rand_s.float(), torch.Size([input_dim, output_dim])).to_dense() self.sketch_matrix1 = nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim1,)), 2 * torch.randint(2, size = (input_dim1,)) - 1, input_dim1, output_dim)) self.sketch_matrix2 = nn.Parameter(generate_sketch_matrix(torch.randint(output_dim, size = (input_dim2,)), 2 * torch.randint(2, size = (input_dim2,)) - 1, input_dim2, output_dim)) def forward(self, bottom1, bottom2): sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().matmul(self.sketch_matrix1).view(-1, self.output_dim) sketch_2 = bottom2.permute(0, 2, 3, 1).contiguous().matmul(self.sketch_matrix2).view(-1, self.output_dim) fft1_real, fft1_imag = torch.rfft(sketch_1, 1).permute(2, 0, 1) fft2_real, fft2_imag = torch.rfft(sketch_2, 1).permute(2, 0, 1) fft_product = torch.stack([fft1_real * fft2_real - fft1_imag * fft2_imag, fft1_real * fft2_imag - fft1_imag * fft2_real], dim = -1) cbp = torch.irfft(fft_product, 1, signal_sizes = (self.output_dim,)).view(len(bottom1), bottom1.size(-2), bottom1.size(-1), self.output_dim) * self.output_dim return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2) -
vadimkantorov revised this gist
Apr 11, 2018 . 1 changed file with 4 additions and 4 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -8,13 +8,13 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1 self.sum_pool = sum_pool if rand_h_1 is None: rand_h_1 = torch.randint(output_dim, size = (input_dim1,)) if rand_s_1 is None: rand_s_1 = 2 * torch.randint(2, size = (input_dim1,)) - 1 if rand_h_2 is None: rand_h_2 = torch.randint(output_dim, size = (input_dim2,)) if rand_s_2 is None: rand_s_2 = 2 * torch.randint(2, size = (input_dim2,)) - 1 generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim), rand_h]], dim = -1).t(), rand_s, torch.Size([input_dim, output_dim])).to_dense() self.sparse_sketch_matrix1 = generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, output_dim) -
vadimkantorov revised this gist
Apr 11, 2018 . 1 changed file with 6 additions and 8 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -4,23 +4,21 @@ class CompactBilinearPooling(nn.Module): def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1 = None, rand_s_1 = None, rand_h_2 = None, rand_s_2 = None): super(CompactBilinearPooling, self).__init__() self.output_dim = output_dim self.sum_pool = sum_pool if rand_h_1 is None: rand_h_1 = torch.randint(output_dim, sizes = (input_dim1,)) if rand_s_1 is None: rand_s_1 = 2 * torch.randint(2, sizes = (input_dim1,)) - 1 if rand_h_2 is None: rand_h_2 = torch.randint(output_dim, sizes = (input_dim2,)) if rand_s_2 is None: rand_s_2 = 2 * torch.randint(2, sizes = (input_dim2,)) - 1 generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim), rand_h]], dim = -1).t(), rand_s, torch.Size([input_dim, output_dim])).to_dense() self.sparse_sketch_matrix1 = generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, output_dim) self.sparse_sketch_matrix2 = generate_sketch_matrix(rand_h_2, rand_s_2, input_dim2, output_dim) def forward(self, bottom1, bottom2): sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().mm(self.sparse_sketch_matrix1).view(-1, output_dim) -
vadimkantorov revised this gist
Apr 11, 2018 . 1 changed file with 4 additions and 8 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -18,8 +18,9 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1 if rand_s_2 is None: rand_s_2 = 2 * torch.randint(2, sizes = (self.input_dim2,)) - 1 generate_sketch_matrix = lambda rand_h, rand_s, input_dim, output_dim: torch.sparse.FloatTensor(torch.stack([torch.arange(input_dim), rand_h]], dim = -1).t(), rand_s, torch.Size([input_dim, output_dim])).to_dense() self.sparse_sketch_matrix1 = generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, self.output_dim) self.sparse_sketch_matrix2 = generate_sketch_matrix(rand_h_2, rand_s_2, input_dim2, self.output_dim) def forward(self, bottom1, bottom2): sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().mm(self.sparse_sketch_matrix1).view(-1, output_dim) @@ -28,9 +29,4 @@ def forward(self, bottom1, bottom2): fft2_real, fft2_imag = torch.rfft(sketch_2, 1).permute(2, 0, 1) fft_product = torch.stack([fft1_real * fft2_real - fft1_imag * fft2_imag, fft1_real * fft2_imag - fft1_imag * fft2_real], dim = -1) cbp = torch.irfft(fft_product).view(len(bottom1), bottom1.size(-2), bottom1.size(-1), self.output_dim) * self.output_dim return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2) -
vadimkantorov revised this gist
Apr 11, 2018 . 1 changed file with 5 additions and 5 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -10,13 +10,13 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1 self.sum_pool = sum_pool if rand_h_1 is None: rand_h_1 = torch.randint(output_dim, sizes = (self.input_dim1,)) if rand_s_1 is None: rand_s_1 = 2 * torch.randint(2, sizes = (self.input_dim1,)) - 1 if rand_h_2 is None: rand_h_2 = torch.randint(output_dim, sizes = (self.input_dim2,)) if rand_s_2 is None: rand_s_2 = 2 * torch.randint(2, sizes = (self.input_dim2,)) - 1 self.sparse_sketch_matrix1 = self.generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, self.output_dim) self.sparse_sketch_matrix2 = self.generate_sketch_matrix(rand_h_2, rand_s_2, input_dim2, self.output_dim) @@ -32,5 +32,5 @@ def forward(self, bottom1, bottom2): @staticmethod def generate_sketch_matrix(rand_h, rand_s, input_dim, output_dim): indices = torch.stack([torch.arange(input_dim), rand_h]], dim = -1) return torch.sparse.FloatTensor(indices.t(), rand_s, torch.Size([input_dim, output_dim])).to_dense() -
vadimkantorov revised this gist
Apr 11, 2018 . 1 changed file with 12 additions and 13 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -18,20 +18,19 @@ def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1 if rand_s_2 is None: rand_s_2 = 2 * torch.randint(2, size = (self.input_dim2, )) - 1 self.sparse_sketch_matrix1 = self.generate_sketch_matrix(rand_h_1, rand_s_1, input_dim1, self.output_dim) self.sparse_sketch_matrix2 = self.generate_sketch_matrix(rand_h_2, rand_s_2, input_dim2, self.output_dim) def forward(self, bottom1, bottom2): sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().mm(self.sparse_sketch_matrix1).view(-1, output_dim) sketch_2 = bottom2.permute(0, 2, 3, 1).contiguous().mm(self.sparse_sketch_matrix2).view(-1, output_dim) fft1_real, fft1_imag = torch.rfft(sketch_1, 1).permute(2, 0, 1) fft2_real, fft2_imag = torch.rfft(sketch_2, 1).permute(2, 0, 1) fft_product = torch.stack([fft1_real * fft2_real - fft1_imag * fft2_imag, fft1_real * fft2_imag - fft1_imag * fft2_real], dim = -1) cbp = torch.irfft(fft_product).view(len(bottom1), bottom1.size(-2), bottom1.size(-1), self.output_dim) * self.output_dim return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2) @staticmethod def generate_sketch_matrix(rand_h, rand_s, input_dim, output_dim): indices = np.concatenate((torch.arange(input_dim)[..., np.newaxis], rand_h[..., np.newaxis]), axis=1) return torch.sparse.FloatTensor(indices.t(), rand_s, torch.Size([input_dim, output_dim])).to_dense() -
vadimkantorov revised this gist
Apr 11, 2018 . 1 changed file with 19 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -2,6 +2,25 @@ import torch.nn as nn class CompactBilinearPooling(nn.Module): def __init__(self, input_dim1, input_dim2, output_dim, sum_pool = True, rand_h_1 = None, rand_s_1 = None, rand_h_2 = None, rand_s_2 = None): super(CompactBilinearPooling, self).__init__() self.input_dim1 = input_dim1 self.input_dim2 = input_dim2 self.output_dim = output_dim self.sum_pool = sum_pool if rand_h_1 is None: rand_h_1 = torch.randint(output_dim, size = (self.input_dim1,)) if rand_s_1 is None: rand_s_1 = 2 * torch.randint(2, size = (self.input_dim1,)) - 1 if rand_h_2 is None: rand_h_2 = torch.randint(output_dim, size = (self.input_dim2,)) if rand_s_2 is None: rand_s_2 = 2 * torch.randint(2, size = (self.input_dim2, )) - 1 self.sparse_sketch_matrix1 = self.generate_sketch_matrix(rand_h_1, rand_s_1, self.output_dim) self.sparse_sketch_matrix2 = self.generate_sketch_matrix(rand_h_2, rand_s_2, self.output_dim) def forward(self, bottom1, bottom2): assert bottom1.size(1) == self.input_dim1 and bottom2.size(1) == self.input_dim2 batch_size, _, height, width = bottom1.size() -
vadimkantorov revised this gist
Apr 11, 2018 . 1 changed file with 4 additions and 13 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -6,22 +6,13 @@ def forward(self, bottom1, bottom2): assert bottom1.size(1) == self.input_dim1 and bottom2.size(1) == self.input_dim2 batch_size, _, height, width = bottom1.size() sketch_1 = bottom1.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1).mm(self.sparse_sketch_matrix1) sketch_2 = bottom2.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim2).mm(self.sparse_sketch_matrix2) fft1_real, fft1_imag = torch.rfft(sketch_1, 1).permute(2, 0, 1) fft2_real, fft2_imag = torch.rfft(sketch_2, 1).permute(2, 0, 1) fft_product_real = fft1_real * fft2_real - fft1_imag * fft2_imag fft_product_imag = fft1_real * fft2_imag - fft1_imag * fft2_real cbp = torch.irfft(torch.stack([fft_product_real, fft_product_imag], dim = -1).view(batch_size, height, width, self.output_dim) * self.output_dim return cbp.sum(dim = 1).sum(dim = 1) if self.sum_pool else cbp.permute(0, 3, 1, 2) -
vadimkantorov revised this gist
Apr 11, 2018 . 1 changed file with 23 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -2,4 +2,26 @@ import torch.nn as nn class CompactBilinearPooling(nn.Module): def forward(self, bottom1, bottom2): assert bottom1.size(1) == self.input_dim1 and bottom2.size(1) == self.input_dim2 batch_size, _, height, width = bottom1.size() bottom1_flat = bottom1.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1) bottom2_flat = bottom2.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim2) sketch_1 = bottom1_flat.mm(self.sparse_sketch_matrix1) sketch_2 = bottom2_flat.mm(self.sparse_sketch_matrix2) fft1_real, fft1_imag = torch.rfft(sketch_1, 1).permute(2, 0, 1) fft2_real, fft2_imag = torch.rfft(sketch_2, 1).permute(2, 0, 1) fft_product_real = fft1_real * fft2_real - fft1_imag * fft2_imag fft_product_imag = fft1_real * fft2_imag - fft1_imag * fft2_real cbp_flat = torch.irfft(torch.stack([fft_product_real, fft_product_imag], dim = -1) cbp = cbp_flat.view(batch_size, height, width, self.output_dim) * self.output_dim if self.sum_pool: return cbp.sum(dim = 1).sum(dim = 1) return cbp.permute(0, 3, 1, 2) -
vadimkantorov created this gist
Apr 11, 2018 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,5 @@ import torch import torch.nn as nn class CompactBilinearPooling(nn.Module): pass