# Mirror of https://github.com/k2-fsa/icefall.git
# Synced 2025-08-14 20:42:22 +00:00 (258 lines, 8.9 KiB, Python)
import torch
|
|
from torch import nn
|
|
|
|
from typing import Optional
|
|
|
|
|
|
class AdaLayerNorm(nn.Module):
    """
    Layer normalization whose affine parameters are selected by a condition id.

    Two embedding tables (``scale`` and ``shift``) each hold one learnable vector of
    size ``embedding_dim`` per condition class. The input is normalized over its last
    dimension (with no built-in affine transform) and then modulated as
    ``norm(x) * scale[id] + shift[id]``.

    Args:
        num_embeddings (int): Number of condition classes.
        embedding_dim (int): Size of the last (normalized) dimension of the input.
        eps (float, optional): Numerical-stability term passed to ``layer_norm``.
            Defaults to 1e-6.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        table_kwargs = dict(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.scale = nn.Embedding(**table_kwargs)
        self.shift = nn.Embedding(**table_kwargs)
        # Start as an identity modulation: scale = 1, shift = 0.
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` and apply the scale/shift looked up for ``cond_embedding_id``.

        The looked-up vectors are combined with the normalized input by plain
        broadcasting. NOTE(review): for a (B, T, C) input this presumably expects a
        scalar or single-element id; a (B,) id would be aligned with the time axis.
        Confirm against callers.
        """
        normalized = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        gain, bias = self.scale(cond_embedding_id), self.shift(cond_embedding_id)
        return normalized * gain + bias
|
|
|
|
|
|
class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.

    In "same" mode this module performs the overlap-add itself and normalizes only the
    samples whose window envelope is non-negligible, so it never fails the NOLA check.
    The full (untrimmed) overlap-add signal is returned.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter. In "same" mode the
            n_fft-point inverse FFT is multiplied by a win_length-point window, so
            win_length must equal n_fft.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".

    Raises:
        ValueError: If ``padding`` is not "center" or "same", or if ``padding`` is
            "same" and ``win_length != n_fft``.
    """

    def __init__(
        self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
    ):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        if padding == "same" and win_length != n_fft:
            # forward() windows the n_fft-point inverse FFT with a win_length-point
            # window and folds with kernel win_length; both require equal sizes.
            # Fail here with a clear message instead of inside forward().
            raise ValueError(
                "win_length must equal n_fft when padding is 'same' "
                f"(got win_length={win_length}, n_fft={n_fft})."
            )
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L).
                For "center" padding, L is whatever ``torch.istft(..., center=True)`` returns.
                For "same" padding, no edge samples are trimmed, so
                L = (T - 1) * hop_length + win_length.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(
                spec,
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,
                center=True,
            )
        if self.padding != "same":
            # Unreachable via __init__; guards against self.padding being changed later.
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        num_frames = spec.shape[-1]

        # Inverse FFT of every frame, then apply the synthesis window.
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add: fold shifts each frame by hop_length and sums them,
        # producing (B, 1, 1, output_size).
        output_size = (num_frames - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[:, 0, 0, :]

        # Window envelope: the overlap-added squared window, which undoes the
        # combined analysis * synthesis windowing.
        window_sq = self.window.square().expand(1, num_frames, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[0, 0, 0, :]

        # Normalize only where the envelope is non-negligible. Samples with a (near)
        # zero envelope are left as they are instead of being divided by ~0; one
        # example is the first sample, where the periodic Hann window is exactly 0.
        norm_indexes = window_envelope > 1e-11
        y[:, norm_indexes] = y[:, norm_indexes] / window_envelope[norm_indexes]

        return y
|
|
|
|
|
|
class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Depthwise conv -> (Ada)LayerNorm -> Linear -> GELU -> Linear -> optional layer
    scale, wrapped in a residual connection.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None
            (or a non-positive value) means no scaling. Defaults to None.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        # Depthwise conv; kernel 7 with padding 3 keeps the time length unchanged.
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
        # A single flag drives both the module choice here and the call form in
        # forward(), so the two can never disagree (previously 0 embeddings built a
        # plain LayerNorm but forward() still called it with a condition id).
        self.adanorm = adanorm_num_embeddings is not None
        if self.adanorm:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise/1x1 convs, implemented with linear layers on (B, T, C) input.
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # Check for None explicitly: the documented default None would otherwise hit
        # `None > 0`, which raises TypeError.
        if layer_scale_init_value is not None and layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma = None

    def forward(
        self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Apply the block with a residual connection.

        Args:
            x (Tensor): Input of shape (B, C, T) with C == dim.
            cond_embedding_id (Tensor, optional): Condition id passed to AdaLayerNorm;
                required when the block was built with ``adanorm_num_embeddings``.

        Returns:
            Tensor: Output with the same shape as ``x``.
        """
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if self.adanorm:
            assert cond_embedding_id is not None
            x = self.norm(x, cond_embedding_id)
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x
|
|
|
|
|
|
class Generator(torch.nn.Module):
    """
    Map acoustic features to a waveform with a ConvNeXt backbone and an ISTFT head.

    The features are embedded by a Conv1d, normalized, refined by ``num_layers``
    ConvNeXt blocks, and projected to ``n_fft + 2`` channels per frame. Those channels
    are split in half into log-magnitude and phase (``n_fft // 2 + 1`` bins each for
    even ``n_fft``), combined into a complex spectrum, and inverted with ``ISTFT``.

    Args:
        feature_dim (int): Number of input feature channels. Defaults to 80.
        dim (int): Hidden size of the backbone. Defaults to 512.
        n_fft (int): FFT size of the ISTFT head, also used as its window length.
            Defaults to 1024.
        hop_length (int): Hop size of the ISTFT head. Defaults to 256.
        intermediate_dim (int): Hidden size of the ConvNeXt MLPs. Defaults to 1536.
        num_layers (int): Number of ConvNeXt blocks. Defaults to 8.
        padding (str): ISTFT padding mode, "center" or "same". Defaults to "same".
        layer_scale_init_value (float, optional): Initial layer scale for every block.
            None means ``1 / num_layers``; a non-positive value disables layer scale.
            Defaults to None.
        adanorm_num_embeddings (int, optional): Number of condition classes for
            AdaLayerNorm. None means unconditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        feature_dim: int = 80,
        dim: int = 512,
        n_fft: int = 1024,
        hop_length: int = 256,
        intermediate_dim: int = 1536,
        num_layers: int = 8,
        padding: str = "same",
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.feature_dim = feature_dim
        self.embed = nn.Conv1d(feature_dim, dim, kernel_size=7, padding=3)

        # A single flag drives both the module choice and forward(), so they cannot
        # disagree (e.g. for adanorm_num_embeddings == 0).
        self.adanorm = adanorm_num_embeddings is not None
        if self.adanorm:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)

        # Fill in the default only when no value was given. `x or default` would
        # also replace an explicit 0.0, making it impossible to disable layer scale.
        if layer_scale_init_value is None:
            layer_scale_init_value = 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    adanorm_num_embeddings=adanorm_num_embeddings,
                )
                for _ in range(num_layers)
            ]
        )

        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        # Applied before out_proj/istft are created, so only the backbone gets the
        # truncated-normal init; out_proj keeps PyTorch's default Linear init.
        self.apply(self._init_weights)

        self.out_proj = torch.nn.Linear(dim, n_fft + 2)
        self.istft = ISTFT(
            n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
        )

    def _init_weights(self, m):
        """Init Conv1d/Linear layers: truncated-normal weights (std 0.02), zero bias."""
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Synthesize audio from features.

        Args:
            x (Tensor): Features of shape (B, feature_dim, T), the layout expected by
                the Conv1d embedding.
            **kwargs: ``bandwidth_id`` (Tensor, optional) is the condition id forwarded
                to AdaLayerNorm. It is required when the model was built with
                ``adanorm_num_embeddings``.

        Returns:
            Tensor: Waveform produced by the ISTFT head, shape (B, L).
        """
        bandwidth_id = kwargs.get("bandwidth_id", None)
        x = self.embed(x)
        # LayerNorm / AdaLayerNorm normalize the last axis, so work in (B, T, C).
        x = x.transpose(1, 2)
        if self.adanorm:
            assert bandwidth_id is not None
            x = self.norm(x, cond_embedding_id=bandwidth_id)
        else:
            x = self.norm(x)
        x = x.transpose(1, 2)  # back to (B, C, T) for the ConvNeXt blocks

        for conv_block in self.convnext:
            x = conv_block(x, cond_embedding_id=bandwidth_id)

        x = self.final_layer_norm(x.transpose(1, 2))
        x = self.out_proj(x).transpose(1, 2)  # (B, n_fft + 2, T)

        # First half of the channels: log-magnitude; second half: phase.
        mag, phase = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
        # Complex spectrum mag * e^{i*phase}, built from its real and imaginary parts.
        S = mag * (torch.cos(phase) + 1j * torch.sin(phase))
        return self.istft(S)
|