Created January 9, 2025 14:02
Gist: NormXU/913e9025965475033b55c09f93aa7d6d
FourierFeatures
```python
import numpy as np
import torch
from torch import nn


class FourierFeatures(nn.Module):
    """Copied from https://github.com/NVIDIA/Cosmos/blob/c47b35b7618a6e263556f3e3fb7cfba3705c08a5/cosmos1/models/diffusion/module/blocks.py
    Related research: https://arxiv.org/pdf/2006.10739

    Implements a layer that generates Fourier features from input tensors, based on randomly sampled
    frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems.

    [B] -> [B, D]

    Parameters:
        num_channels (int): The number of Fourier features D to generate.
        bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1.
        normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2) to normalize the
            variance of the features (a cosine with a uniformly random phase has variance 1/2). Defaults to False.

    Example:
        >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
        >>> x = torch.randn(10)  # Example input: a 1-D tensor of shape [B]
        >>> output = layer(x)
        >>> print(output.shape)  # Expected shape: (10, 256)
    """

    def __init__(self, num_channels, bandwidth=1, normalize=False):
        super().__init__()
        # Frequencies ~ N(0, (2*pi*bandwidth)^2) and phases ~ U[0, 2*pi); both are saved in the state dict.
        self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
        self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
        self.gain = np.sqrt(2) if normalize else 1

    def forward(self, x, gain: float = 1.0):
        """
        Apply the Fourier feature transformation to the input tensor.

        Args:
            x (torch.Tensor): The 1-D input tensor of shape [B].
            gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.

        Returns:
            torch.Tensor: Fourier features of shape [B, num_channels], cast back to the dtype of x.
        """
        in_dtype = x.dtype
        # Outer product x_i * freqs_d (Tensor.ger, equivalent to torch.outer) in float32, plus a per-channel phase.
        x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
        # Cosine, scaled by the constructor gain and the per-call gain, then cast back to the input dtype.
        x = x.cos().mul(self.gain * gain).to(in_dtype)
        return x
```
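To check the `[B] -> [B, D]` shape and the effect of `normalize` without PyTorch, the forward pass can be mirrored in NumPy. This is a minimal sketch, not part of the Cosmos code: the helpers `sample_fourier_buffers` and `fourier_features_np` are hypothetical names, and the buffers are drawn with NumPy's `default_rng`, so the values will not match a torch-initialized layer even for the same seed.

```python
import numpy as np


def sample_fourier_buffers(num_channels, bandwidth=1.0, seed=0):
    """Hypothetical helper: draw freqs/phases with the same distributions as FourierFeatures.__init__."""
    rng = np.random.default_rng(seed)
    freqs = 2 * np.pi * bandwidth * rng.standard_normal(num_channels)  # ~ N(0, (2*pi*bandwidth)^2)
    phases = 2 * np.pi * rng.random(num_channels)  # ~ U[0, 2*pi)
    return freqs, phases


def fourier_features_np(x, freqs, phases, normalize=False, gain=1.0):
    """Hypothetical NumPy mirror of FourierFeatures.forward: [B] -> [B, D]."""
    base_gain = np.sqrt(2) if normalize else 1.0
    # Outer product x_i * freqs_d in float32, plus a per-channel phase (mirrors .ger(...).add(...)).
    angles = np.outer(x.astype(np.float32), freqs.astype(np.float32)) + phases.astype(np.float32)
    # Cosine, scaled by both gains, cast back to the input dtype.
    return (np.cos(angles) * (base_gain * gain)).astype(x.dtype)


freqs, phases = sample_fourier_buffers(num_channels=256, bandwidth=0.5)
x = np.linspace(-1.0, 1.0, 10)  # a batch of B = 10 scalars
out = fourier_features_np(x, freqs, phases, normalize=True)
print(out.shape)  # (10, 256)

# Normalization check: fix x and the frequency, vary only the phase.
rng = np.random.default_rng(1)
many_phases = 2 * np.pi * rng.random(100_000)
feats = fourier_features_np(np.array([0.3]), np.full(100_000, 5.0), many_phases, normalize=True)
print(float(feats.mean()), float((feats ** 2).mean()))  # close to 0 and 1
```

The second check isolates the phase because that is the only randomness the sqrt(2) factor is calibrated against: averaged over a uniform phase, the cosine has zero mean and mean square 1/2, so scaling by sqrt(2) brings the feature variance to about 1.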
FourierFeatures in Math Equations:
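Read directly off the code, as a sketch in our own notation rather than the original author's: let b be `bandwidth`, D be `num_channels`, g_0 be `self.gain`, and g be the `gain` argument of `forward`; the symbols z_d and u_d for the underlying standard samples are our own labels. Each input scalar x_i of a batch of size B then maps to D features:

```latex
\begin{aligned}
\omega_d &= 2\pi b \, z_d, \quad z_d \sim \mathcal{N}(0, 1),
\qquad \varphi_d = 2\pi u_d, \quad u_d \sim \mathcal{U}[0, 1),
\qquad d = 1, \dots, D \\
y_{i,d} &= g_0 \, g \, \cos(\omega_d x_i + \varphi_d),
\qquad i = 1, \dots, B,
\qquad g_0 = \begin{cases} \sqrt{2} & \text{if normalize} \\ 1 & \text{otherwise} \end{cases} \\
\mathbb{E}_{\varphi}\!\left[\cos^2(\theta + \varphi)\right] &= \tfrac{1}{2}
\;\Longrightarrow\;
\mathbb{E}_{\varphi}\!\left[\big(\sqrt{2}\cos(\theta + \varphi)\big)^2\right] = 1
\end{aligned}
```

The last line, with the phase uniform on [0, 2π) and θ = ω_d x_i held fixed, is the reason `normalize=True` multiplies by sqrt(2): since the cosine also has zero mean over the phase, the scaled features have unit variance.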
Related Blogs: