import torch
from torch import nn
from typing import Optional


class AdaLayerNorm(nn.Module):
    """
    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes.

    The input is layer-normalized over its last dimension (no built-in affine
    transform) and then scaled and shifted by per-class vectors looked up from
    two embedding tables. Scales are initialised to one and shifts to zero, so
    a freshly constructed module is equivalent to a plain LayerNorm.

    Args:
        num_embeddings (int): Number of embeddings.
        embedding_dim (int): Dimension of the embeddings.
        eps (float, optional): Value added to the variance for numerical
            stability. Defaults to 1e-6.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        self.shift = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )
        # Identity affine transform at initialisation.
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        """
        Normalize `x` and apply the class-conditional scale and shift.

        Args:
            x (Tensor): Input whose last dimension has size `embedding_dim`.
            cond_embedding_id (Tensor): Integer class id(s). The looked-up
                scale/shift have shape `cond_embedding_id.shape + (embedding_dim,)`
                and are broadcast against `x`. A scalar or shape-(1,) id
                therefore applies a single class to every element. For a 3-D
                `x`, per-example ids need shape (B, 1) to broadcast correctly.

        Returns:
            Tensor: Normalized, scaled and shifted tensor (the shape of `x` for
            a scalar or shape-(1,) id).
        """
        scale = self.scale(cond_embedding_id)
        shift = self.shift(cond_embedding_id)
        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        x = x * scale + shift
        return x


class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding
    (other than `center=True`) with windowing. This is because the NOLA
    (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in
    "same" padding analogous to CNNs.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter. The
            "same" path multiplies each `n_fft`-sample inverse-FFT frame by the
            `win_length`-sample window, so it requires `win_length == n_fft`.
        padding (str, optional): Type of padding. Options are "center" or
            "same". Defaults to "same".

    Raises:
        ValueError: If `padding` is not "center" or "same".
    """

    def __init__(
        self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
    ):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        # Registered as a buffer so it follows the module across devices/dtypes.
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is
                the batch size, N is the number of frequency bins, and T is the
                number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L). With
            "same" padding, L = (T - 1) * hop_length + win_length.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(
                spec,
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,
                center=True,
            )
        elif self.padding != "same":
            # Only reachable if `self.padding` was mutated after construction.
            raise ValueError("Padding must be 'center' or 'same'.")

        # NOTE(review): for "same" padding the per-side trim would be
        # (win_length - hop_length) // 2. It is NOT applied here, so the output
        # keeps the full overlap-add length. Edge samples whose window
        # envelope is ~0 are left unnormalized instead. Confirm whether callers
        # trim the result.
        assert spec.dim() == 3, "Expected a 3D tensor as input"
        num_frames = spec.shape[-1]

        # Inverse FFT along the frequency axis -> (B, n_fft, T) real frames.
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add: fold treats each column as one frame placed at
        # multiples of hop_length along a length-`output_size` signal.
        output_size = (num_frames - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[:, 0, 0, :]

        # Window envelope: overlap-added squared window, same geometry as above.
        window_sq = self.window.square().expand(1, num_frames, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[0, 0, 0]  # explicit indexing -> (output_size,), independent of squeeze()

        # Normalize only where the envelope is non-negligible to avoid dividing
        # by ~0 at the (untrimmed) edges.
        norm_indexes = window_envelope > 1e-11
        y[:, norm_indexes] = y[:, norm_indexes] / window_envelope[norm_indexes]
        return y


class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale.
            None (or a non-positive value) means no scaling. Defaults to None.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.adanorm = adanorm_num_embeddings is not None
        # Branch on the same flag that forward() uses, so norm type and call
        # signature always agree.
        if self.adanorm:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Previously `layer_scale_init_value > 0` raised TypeError for the
        # documented default of None.
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value is not None and layer_scale_init_value > 0
            else None
        )

    def forward(
        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Apply the residual ConvNeXt block.

        Args:
            x (Tensor): Input of shape (B, C, T) with C == `dim`.
            cond_embedding_id (Tensor, optional): Class id(s) for AdaLayerNorm.
                Required when the block was built with `adanorm_num_embeddings`.

        Returns:
            Tensor: Output of shape (B, C, T).
        """
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if self.adanorm:
            assert cond_embedding_id is not None
            x = self.norm(x, cond_embedding_id)
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x


class Generator(torch.nn.Module):
    """
    ConvNeXt backbone followed by an ISTFT head that turns input features into audio.

    Pipeline: Conv1d embedding -> (Ada)LayerNorm -> `num_layers` ConvNeXt
    blocks -> LayerNorm -> Linear projection to `n_fft + 2` channels. The
    projection is split into log-magnitude and phase halves, which form a
    complex spectrogram that is inverted with `ISTFT`.

    Args:
        feature_dim (int): Number of channels of the input features. Defaults to 80.
        dim (int): Hidden width of the backbone. Defaults to 512.
        n_fft (int): FFT size (also used as the ISTFT window length). The
            `n_fft + 2` projection splits evenly into `n_fft // 2 + 1` bins only
            for even `n_fft`. Defaults to 1024.
        hop_length (int): ISTFT hop length. Defaults to 256.
        intermediate_dim (int): Hidden size inside each ConvNeXt block. Defaults to 1536.
        num_layers (int): Number of ConvNeXt blocks. Defaults to 8.
        padding (str): ISTFT padding mode, "center" or "same". Defaults to "same".
        layer_scale_init_value (float, optional): Layer-scale init for the blocks.
            Falsy values (None or 0) are replaced by `1 / num_layers`.
        adanorm_num_embeddings (int, optional): If given, use AdaLayerNorm
            conditioned on `bandwidth_id`. Defaults to None.
    """

    def __init__(
        self,
        feature_dim: int = 80,
        dim: int = 512,
        n_fft: int = 1024,
        hop_length: int = 256,
        intermediate_dim: int = 1536,
        num_layers: int = 8,
        padding: str = "same",
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.feature_dim = feature_dim
        self.embed = nn.Conv1d(feature_dim, dim, kernel_size=7, padding=3)
        self.adanorm = adanorm_num_embeddings is not None
        # Same flag as forward() uses, keeping norm type and call signature aligned.
        if self.adanorm:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    adanorm_num_embeddings=adanorm_num_embeddings,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        # Applied before `out_proj` exists, so only the backbone gets the
        # truncated-normal init; `out_proj` keeps PyTorch's default init.
        self.apply(self._init_weights)

        self.out_proj = torch.nn.Linear(dim, n_fft + 2)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def _init_weights(self, m):
        """Truncated-normal weights / zero bias for Conv1d and Linear layers."""
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:  # layers built with bias=False have no bias
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Generate a waveform from input features.

        Args:
            x (Tensor): Features of shape (B, feature_dim, T).
            **kwargs: `bandwidth_id` (Tensor), required when the model uses
                AdaLayerNorm. It is passed to every norm as the class id.

        Returns:
            Tensor: Audio of shape (B, L) as produced by `ISTFT`.
        """
        bandwidth_id = kwargs.get("bandwidth_id", None)

        x = self.embed(x)
        if self.adanorm:
            assert bandwidth_id is not None
            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
        else:
            x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x, cond_embedding_id=bandwidth_id)
        x = self.final_layer_norm(x.transpose(1, 2))

        # (B, T, n_fft + 2) -> (B, n_fft + 2, T), then split channels into
        # log-magnitude and phase halves.
        x = self.out_proj(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(
            mag, max=1e2
        )  # safeguard to prevent excessively large magnitudes
        real = torch.cos(p)
        imag = torch.sin(p)
        spec = mag * (real + 1j * imag)
        audio = self.istft(spec)
        return audio