@NormXU
Created January 9, 2025 14:02
FourierFeatures
import numpy as np
import torch
import torch.nn as nn


class FourierFeatures(nn.Module):
    """ Copied from https://github.com/NVIDIA/Cosmos/blob/c47b35b7618a6e263556f3e3fb7cfba3705c08a5/cosmos1/models/diffusion/module/blocks.py
    Related Research: https://arxiv.org/pdf/2006.10739

    Implements a layer that generates Fourier features from a 1-D input tensor, based on randomly sampled
    frequencies and phases. This helps networks learn high-frequency functions of low-dimensional inputs.

    Shape: [B] -> [B, num_channels]

    Parameters:
        num_channels (int): The number of Fourier features to generate.
        bandwidth (float, optional): Scaling factor for the sampled frequencies, which are drawn as
            2 * pi * bandwidth * N(0, 1). Defaults to 1.
        normalize (bool, optional): If True, the outputs are scaled by sqrt(2). Because the phases are
            uniform on [0, 2*pi), each cosine feature has variance 1/2, so this scaling yields
            unit-variance features. Defaults to False.

    Example:
        >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
        >>> x = torch.randn(10)  # Example input: one scalar per batch element
        >>> output = layer(x)
        >>> print(output.shape)
        torch.Size([10, 256])
    """

    def __init__(self, num_channels, bandwidth=1, normalize=False):
        super().__init__()
        # Fixed random frequencies ~ N(0, (2*pi*bandwidth)^2) and phases ~ U(0, 2*pi), saved in the state dict
        self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
        self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
        self.gain = np.sqrt(2) if normalize else 1

    def forward(self, x, gain: float = 1.0):
        """
        Apply the Fourier feature transformation to the input tensor.

        Args:
            x (torch.Tensor): 1-D input tensor of shape [B].
            gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1.

        Returns:
            torch.Tensor: Tensor of shape [B, num_channels] in the dtype of the input.
        """
        in_dtype = x.dtype
        # Compute in float32: outer product x_i * freqs_j, then add the per-channel phase
        x = torch.outer(x.to(torch.float32), self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
        # Cosine, apply the (optional sqrt(2)) normalization and the call-time gain, restore the input dtype
        x = x.cos().mul(self.gain * gain).to(in_dtype)
        return x
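For readers without PyTorch at hand, the mapping in `forward` can be reproduced with the standard library. The sketch below is a hypothetical reference helper (`fourier_features_ref` is not part of Cosmos) that draws its own frequencies and phases with Python's `random`, so its values will not match a torch-initialized layer. It only illustrates the [B] -> [B, num_channels] shape and the cos(x_i * f_j + p_j) structure.

```python
import math
import random


def fourier_features_ref(x, num_channels, bandwidth=1.0, normalize=False, gain=1.0, seed=0):
    """Hypothetical pure-Python sketch of FourierFeatures.forward.

    Maps a length-B list of floats to a B x num_channels nested list.
    Frequencies and phases come from Python's `random`, not torch, so the
    numbers differ from the layer; only the structure is the same.
    """
    rng = random.Random(seed)
    # freqs_j = 2*pi*bandwidth*z_j with z_j ~ N(0, 1), like the torch.randn buffer
    freqs = [2 * math.pi * bandwidth * rng.gauss(0.0, 1.0) for _ in range(num_channels)]
    # phases_j ~ U(0, 2*pi), like the torch.rand buffer
    phases = [2 * math.pi * rng.random() for _ in range(num_channels)]
    # sqrt(2) normalization (if requested) times the call-time gain
    scale = (math.sqrt(2) if normalize else 1.0) * gain
    # Outer product x_i * freqs_j plus the per-channel phase, then cosine
    return [[scale * math.cos(xi * f + p) for f, p in zip(freqs, phases)] for xi in x]


out = fourier_features_ref([0.1, 0.5, 2.0], num_channels=8, bandwidth=0.5, normalize=True)
print(len(out), len(out[0]))  # 3 8
```

Since every feature is a scaled cosine, each output is bounded by the overall gain, and two equal inputs always produce identical rows.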
NormXU commented Jan 9, 2025

FourierFeatures as a math equation:

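$$\mathbf{F} = \cos\left( \mathbf{c} \otimes \mathbf{f} + \mathbf{p} \right) \times \sqrt{2}$$

where $\mathbf{c} \in \mathbb{R}^{B}$ is the input, $\mathbf{f} \in \mathbb{R}^{D}$ are the frequencies ($f_j = 2\pi \cdot \text{bandwidth} \cdot z_j$, $z_j \sim \mathcal{N}(0, 1)$), $\mathbf{p} \in \mathbb{R}^{D}$ are the phases ($p_j \sim \mathcal{U}(0, 2\pi)$), $\otimes$ is the outer product, and $\mathbf{p}$ is broadcast over the $B$ rows. Elementwise, $F_{ij} = \sqrt{2}\cos(c_i f_j + p_j)$. The $\sqrt{2}$ factor applies when `normalize=True` (otherwise it is 1), and `forward` additionally multiplies by its `gain` argument.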
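Why $\sqrt{2}$: because $p_j$ is uniform on $[0, 2\pi)$ and independent of $c_i f_j$, the angle $c_i f_j + p_j$ is uniform modulo $2\pi$, so $\mathbb{E}[\cos] = 0$ and $\mathbb{E}[\cos^2] = 1/2$. Scaling by $\sqrt{2}$ therefore gives zero-mean, unit-variance features. The Monte Carlo sketch below checks this numerically; the input value `c = 0.7`, `bandwidth = 1.0`, and the helper name `normalized_feature_moments` are arbitrary illustrative choices, not part of the layer.

```python
import math
import random


def normalized_feature_moments(num_samples=100_000, seed=0):
    """Monte Carlo estimate of the mean and variance of sqrt(2) * cos(c * f + p).

    Assumes, as in the layer, f = 2*pi*bandwidth*z with z ~ N(0, 1) and
    p ~ U(0, 2*pi). The input c and bandwidth are arbitrary example values.
    """
    rng = random.Random(seed)
    c, bandwidth = 0.7, 1.0
    vals = []
    for _ in range(num_samples):
        f = 2 * math.pi * bandwidth * rng.gauss(0.0, 1.0)
        p = 2 * math.pi * rng.random()
        vals.append(math.sqrt(2) * math.cos(c * f + p))
    mean = sum(vals) / num_samples
    var = sum((v - mean) ** 2 for v in vals) / num_samples
    return mean, var


mean, var = normalized_feature_moments()
print(f"mean={mean:.3f}, var={var:.3f}")  # mean should be near 0, var near 1
```

The result does not depend on the chosen `c` or `bandwidth`, since the uniform phase alone makes the angle uniform.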
