class FourierFeatures(nn.Module):
    """Map a 1-D batch of scalars to random Fourier features: [B] -> [B, D].

    Copied from
    https://github.com/NVIDIA/Cosmos/blob/c47b35b7618a6e263556f3e3fb7cfba3705c08a5/cosmos1/models/diffusion/module/blocks.py

    Related research: https://arxiv.org/pdf/2006.10739

    Each output channel ``d`` is ``cos(x * freqs[d] + phases[d])`` scaled by a
    gain. Frequencies are drawn once at construction from a normal
    distribution scaled by ``2 * pi * bandwidth``; phases are drawn uniformly
    from ``[0, 2 * pi)``. Both are stored as persistent buffers, so they are
    saved in the ``state_dict`` and are not trained.

    Parameters:
        num_channels (int): The number of Fourier features (D) to generate.
        bandwidth (float, optional): Scaling factor for the sampled
            frequencies. Defaults to 1.
        normalize (bool, optional): If True, outputs are scaled by sqrt(2),
            usually to normalize the variance of the features.
            Defaults to False.

    Example:
        >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True)
        >>> x = torch.randn(10)  # 1-D input: one scalar per batch element
        >>> output = layer(x)
        >>> print(output.shape)
        torch.Size([10, 256])
    """

    def __init__(self, num_channels: int, bandwidth: float = 1, normalize: bool = False) -> None:
        super().__init__()
        # NOTE: "freqs" is sampled before "phases"; keep this order so that a
        # fixed RNG seed keeps producing the same buffers.
        self.register_buffer(
            "freqs",
            2 * np.pi * bandwidth * torch.randn(num_channels),
            persistent=True,
        )
        self.register_buffer(
            "phases",
            2 * np.pi * torch.rand(num_channels),
            persistent=True,
        )
        self.gain = np.sqrt(2) if normalize else 1

    def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
        """Apply the Fourier feature transformation to the input tensor.

        Args:
            x (torch.Tensor): 1-D input tensor of shape [B]. Inputs of any
                other rank are rejected by the outer product.
            gain (float, optional): Additional gain multiplied with the
                module's own gain during the forward pass. Defaults to 1.

        Returns:
            torch.Tensor: Tensor of shape [B, num_channels] in the same dtype
            as ``x``.
        """
        in_dtype = x.dtype
        # Compute in float32 regardless of the input/buffer dtype, and cast
        # back to the caller's dtype only at the end.
        freqs = self.freqs.to(torch.float32)
        phases = self.phases.to(torch.float32)
        # torch.outer replaces the deprecated Tensor.ger alias: [B] x [D] -> [B, D].
        angles = torch.outer(x.to(torch.float32), freqs).add(phases)
        return angles.cos().mul(self.gain * gain).to(in_dtype)