Add vocos vocoder

This commit is contained in:
pkufool 2024-11-15 15:01:57 +08:00
parent 57451b0382
commit 9c917a0366
10 changed files with 2808 additions and 0 deletions

View File

@ -0,0 +1,127 @@
from typing import Optional
import torch
from torch import nn
from torch.nn.utils import weight_norm
from modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm
class Backbone(nn.Module):
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Args:
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
C denotes output features, and L is the sequence length.
Returns:
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
and H denotes the model dimension.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class VocosBackbone(Backbone):
"""
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
num_layers (int): Number of ConvNeXtBlock layers.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
None means non-conditional model. Defaults to None.
"""
def __init__(
self,
input_channels: int,
dim: int,
intermediate_dim: int,
num_layers: int,
layer_scale_init_value: Optional[float] = None,
adanorm_num_embeddings: Optional[int] = None,
):
super().__init__()
self.input_channels = input_channels
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
self.adanorm = adanorm_num_embeddings is not None
if adanorm_num_embeddings:
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
self.convnext = nn.ModuleList(
[
ConvNeXtBlock(
dim=dim,
intermediate_dim=intermediate_dim,
layer_scale_init_value=layer_scale_init_value,
adanorm_num_embeddings=adanorm_num_embeddings,
)
for _ in range(num_layers)
]
)
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
bandwidth_id = kwargs.get("bandwidth_id", None)
x = self.embed(x)
if self.adanorm:
assert bandwidth_id is not None
x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
else:
x = self.norm(x.transpose(1, 2))
x = x.transpose(1, 2)
for conv_block in self.convnext:
x = conv_block(x, cond_embedding_id=bandwidth_id)
x = self.final_layer_norm(x.transpose(1, 2))
return x
class VocosResNetBackbone(Backbone):
"""
Vocos backbone module built with ResBlocks.
Args:
input_channels (int): Number of input features channels.
dim (int): Hidden dimension of the model.
num_blocks (int): Number of ResBlock1 blocks.
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
"""
def __init__(
self,
input_channels,
dim,
num_blocks,
layer_scale_init_value=None,
):
super().__init__()
self.input_channels = input_channels
self.embed = weight_norm(
nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
)
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
self.resnet = nn.Sequential(
*[
ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
for _ in range(num_blocks)
]
)
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.embed(x)
x = self.resnet(x)
x = x.transpose(1, 2)
return x

View File

@ -0,0 +1,296 @@
from typing import List, Optional, Tuple
import torch
from einops import rearrange
from torch import nn
from torch.nn import Conv2d
from torch.nn.utils import weight_norm
from torchaudio.transforms import Spectrogram
class MultiPeriodDiscriminator(nn.Module):
"""
Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
Additionally, it allows incorporating conditional information with a learned embeddings table.
Args:
periods (tuple[int]): Tuple of periods for each discriminator.
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
Defaults to None.
"""
def __init__(
self,
periods: Tuple[int, ...] = (2, 3, 5, 7, 11),
num_embeddings: Optional[int] = None,
):
super().__init__()
self.discriminators = nn.ModuleList(
[DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods]
)
def forward(
self,
y: torch.Tensor,
y_hat: torch.Tensor,
bandwidth_id: Optional[torch.Tensor] = None,
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorP(nn.Module):
def __init__(
self,
period: int,
in_channels: int = 1,
kernel_size: int = 5,
stride: int = 3,
lrelu_slope: float = 0.1,
num_embeddings: Optional[int] = None,
):
super().__init__()
self.period = period
self.convs = nn.ModuleList(
[
weight_norm(
Conv2d(
in_channels,
32,
(kernel_size, 1),
(stride, 1),
padding=(kernel_size // 2, 0),
)
),
weight_norm(
Conv2d(
32,
128,
(kernel_size, 1),
(stride, 1),
padding=(kernel_size // 2, 0),
)
),
weight_norm(
Conv2d(
128,
512,
(kernel_size, 1),
(stride, 1),
padding=(kernel_size // 2, 0),
)
),
weight_norm(
Conv2d(
512,
1024,
(kernel_size, 1),
(stride, 1),
padding=(kernel_size // 2, 0),
)
),
weight_norm(
Conv2d(
1024,
1024,
(kernel_size, 1),
(1, 1),
padding=(kernel_size // 2, 0),
)
),
]
)
if num_embeddings is not None:
self.emb = torch.nn.Embedding(
num_embeddings=num_embeddings, embedding_dim=1024
)
torch.nn.init.zeros_(self.emb.weight)
self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
self.lrelu_slope = lrelu_slope
def forward(
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
x = x.unsqueeze(1)
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for i, l in enumerate(self.convs):
x = l(x)
x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
if i > 0:
fmap.append(x)
if cond_embedding_id is not None:
emb = self.emb(cond_embedding_id)
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
else:
h = 0
x = self.conv_post(x)
fmap.append(x)
x += h
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiResolutionDiscriminator(nn.Module):
def __init__(
self,
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
num_embeddings: Optional[int] = None,
):
"""
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
Additionally, it allows incorporating conditional information with a learned embeddings table.
Args:
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
Defaults to None.
"""
super().__init__()
self.discriminators = nn.ModuleList(
[
DiscriminatorR(window_length=w, num_embeddings=num_embeddings)
for w in fft_sizes
]
)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorR(nn.Module):
def __init__(
self,
window_length: int,
num_embeddings: Optional[int] = None,
channels: int = 32,
hop_factor: float = 0.25,
bands: Tuple[Tuple[float, float], ...] = (
(0.0, 0.1),
(0.1, 0.25),
(0.25, 0.5),
(0.5, 0.75),
(0.75, 1.0),
),
):
super().__init__()
self.window_length = window_length
self.hop_factor = hop_factor
self.spec_fn = Spectrogram(
n_fft=window_length,
hop_length=int(window_length * hop_factor),
win_length=window_length,
power=None,
)
n_fft = window_length // 2 + 1
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
self.bands = bands
convs = lambda: nn.ModuleList(
[
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
),
]
)
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
if num_embeddings is not None:
self.emb = torch.nn.Embedding(
num_embeddings=num_embeddings, embedding_dim=channels
)
torch.nn.init.zeros_(self.emb.weight)
self.conv_post = weight_norm(
nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
)
def spectrogram(self, x):
# Remove DC offset
x = x - x.mean(dim=-1, keepdims=True)
# Peak normalize the volume of input audio
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
x = self.spec_fn(x)
x = torch.view_as_real(x)
x = rearrange(x, "b f t c -> b c t f")
# Split into bands
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
return x_bands
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
x_bands = self.spectrogram(x)
fmap = []
x = []
for band, stack in zip(x_bands, self.band_convs):
for i, layer in enumerate(stack):
band = layer(band)
band = torch.nn.functional.leaky_relu(band, 0.1)
if i > 0:
fmap.append(band)
x.append(band)
x = torch.cat(x, dim=-1)
if cond_embedding_id is not None:
emb = self.emb(cond_embedding_id)
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
else:
h = 0
x = self.conv_post(x)
fmap.append(x)
x += h
return x, fmap

View File

@ -0,0 +1,178 @@
from typing import Optional
import torch
from torch import nn
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
from spectral_ops import IMDCT, ISTFT
from modules import symexp
class FourierHead(nn.Module):
"""Base class for inverse fourier modules."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
raise NotImplementedError("Subclasses must implement the forward method.")
class ISTFTHead(FourierHead):
"""
ISTFT Head module for predicting STFT complex coefficients.
Args:
dim (int): Hidden dimension of the model.
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames, which should align with
the resolution of the input features.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
super().__init__()
out_dim = n_fft + 2
self.out = torch.nn.Linear(dim, out_dim)
self.istft = ISTFT(
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the ISTFTHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x).transpose(1, 2)
mag, p = x.chunk(2, dim=1)
mag = torch.exp(mag)
mag = torch.clip(
mag, max=1e2
) # safeguard to prevent excessively large magnitudes
# wrapping happens here. These two lines produce real and imaginary value
x = torch.cos(p)
y = torch.sin(p)
# recalculating phase here does not produce anything new
# only costs time
# phase = torch.atan2(y, x)
# S = mag * torch.exp(phase * 1j)
# better directly produce the complex value
S = mag * (x + 1j * y)
audio = self.istft(S)
return audio
class IMDCTSymExpHead(FourierHead):
"""
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
Args:
dim (int): Hidden dimension of the model.
mdct_frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
based on perceptual scaling. Defaults to None.
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
"""
def __init__(
self,
dim: int,
mdct_frame_len: int,
padding: str = "same",
sample_rate: Optional[int] = None,
clip_audio: bool = False,
):
super().__init__()
out_dim = mdct_frame_len // 2
self.out = nn.Linear(dim, out_dim)
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
self.clip_audio = clip_audio
if sample_rate is not None:
# optionally init the last layer following mel-scale
m_max = _hz_to_mel(sample_rate // 2)
m_pts = torch.linspace(0, m_max, out_dim)
f_pts = _mel_to_hz(m_pts)
scale = 1 - (f_pts / f_pts.max())
with torch.no_grad():
self.out.weight.mul_(scale.view(-1, 1))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the IMDCTSymExpHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x)
x = symexp(x)
x = torch.clip(
x, min=-1e2, max=1e2
) # safeguard to prevent excessively large magnitudes
audio = self.imdct(x)
if self.clip_audio:
audio = torch.clip(x, min=-1.0, max=1.0)
return audio
class IMDCTCosHead(FourierHead):
"""
IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
Args:
dim (int): Hidden dimension of the model.
mdct_frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
"""
def __init__(
self,
dim: int,
mdct_frame_len: int,
padding: str = "same",
clip_audio: bool = False,
):
super().__init__()
self.clip_audio = clip_audio
self.out = nn.Linear(dim, mdct_frame_len)
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the IMDCTCosHead module.
Args:
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
L is the sequence length, and H denotes the model dimension.
Returns:
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
"""
x = self.out(x)
m, p = x.chunk(2, dim=2)
m = torch.exp(m).clip(
max=1e2
) # safeguard to prevent excessively large magnitudes
audio = self.imdct(m * torch.cos(p))
if self.clip_audio:
audio = torch.clip(x, min=-1.0, max=1.0)
return audio

View File

@ -0,0 +1,133 @@
from typing import List, Tuple
import torch
import torchaudio
from torch import nn
from modules import safe_log
class MelSpecReconstructionLoss(nn.Module):
"""
L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
"""
def __init__(
self,
sample_rate: int = 24000,
n_fft: int = 1024,
hop_length: int = 256,
n_mels: int = 100,
):
super().__init__()
self.mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
hop_length=hop_length,
n_mels=n_mels,
center=True,
power=1,
)
def forward(self, y_hat, y) -> torch.Tensor:
"""
Args:
y_hat (Tensor): Predicted audio waveform.
y (Tensor): Ground truth audio waveform.
Returns:
Tensor: L1 loss between the mel-scaled magnitude spectrograms.
"""
mel_hat = safe_log(self.mel_spec(y_hat))
mel = safe_log(self.mel_spec(y))
loss = torch.nn.functional.l1_loss(mel, mel_hat)
return loss
class GeneratorLoss(nn.Module):
"""
Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
"""
def forward(
self, disc_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Args:
disc_outputs (List[Tensor]): List of discriminator outputs.
Returns:
Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
the sub-discriminators
"""
loss = torch.zeros(
1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype
)
gen_losses = []
for dg in disc_outputs:
l = torch.mean(torch.clamp(1 - dg, min=0))
gen_losses.append(l)
loss += l
return loss, gen_losses
class DiscriminatorLoss(nn.Module):
"""
Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
"""
def forward(
self,
disc_real_outputs: List[torch.Tensor],
disc_generated_outputs: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
"""
Args:
disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
Returns:
Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
the sub-discriminators for real outputs, and a list of
loss values for generated outputs.
"""
loss = torch.zeros(
1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype
)
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean(torch.clamp(1 - dr, min=0))
g_loss = torch.mean(torch.clamp(1 + dg, min=0))
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
class FeatureMatchingLoss(nn.Module):
"""
Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
"""
def forward(
self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
) -> torch.Tensor:
"""
Args:
fmap_r (List[List[Tensor]]): List of feature maps from real samples.
fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
Returns:
Tensor: The calculated feature matching loss.
"""
loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss

View File

@ -0,0 +1,51 @@
import logging
import torch
from backbone import VocosBackbone
from heads import ISTFTHead
from discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
from loss import (
DiscriminatorLoss,
GeneratorLoss,
FeatureMatchingLoss,
MelSpecReconstructionLoss,
)
class Vocos(torch.nn.Module):
def __init__(
self,
dim: int = 512,
n_fft: int = 1024,
hop_length: int = 256,
feature_dim: int = 80,
intermediate_dim: int = 1536,
num_layers: int = 8,
padding: str = "same",
sample_rate: int = 22050,
):
super(Vocos, self).__init__()
self.backbone = VocosBackbone(
input_channels=feature_dim,
dim=dim,
intermediate_dim=intermediate_dim,
num_layers=num_layers,
)
self.head = ISTFTHead(
dim=dim,
n_fft=n_fft,
hop_length=hop_length,
padding=padding,
)
self.mpd = MultiPeriodDiscriminator()
self.mrd = MultiResolutionDiscriminator()
self.disc_loss = DiscriminatorLoss()
self.gen_loss = GeneratorLoss()
self.feat_matching_loss = FeatureMatchingLoss()
self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)
def forward(self, features: torch.Tensor):
x = self.backbone(features)
audio_output = self.head(x)
return audio_output

View File

@ -0,0 +1,213 @@
from typing import Optional, Tuple
import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm
class ConvNeXtBlock(nn.Module):
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
Args:
dim (int): Number of input channels.
intermediate_dim (int): Dimensionality of the intermediate layer.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
None means non-conditional LayerNorm. Defaults to None.
"""
def __init__(
self,
dim: int,
intermediate_dim: int,
layer_scale_init_value: float,
adanorm_num_embeddings: Optional[int] = None,
):
super().__init__()
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.adanorm = adanorm_num_embeddings is not None
if adanorm_num_embeddings:
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
else:
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(intermediate_dim, dim)
self.gamma = (
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
if layer_scale_init_value > 0
else None
)
def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
residual = x
x = self.dwconv(x)
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
if self.adanorm:
assert cond_embedding_id is not None
x = self.norm(x, cond_embedding_id)
else:
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
x = residual + x
return x
class AdaLayerNorm(nn.Module):
"""
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
Args:
num_embeddings (int): Number of embeddings.
embedding_dim (int): Dimension of the embeddings.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.dim = embedding_dim
self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
torch.nn.init.ones_(self.scale.weight)
torch.nn.init.zeros_(self.shift.weight)
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
scale = self.scale(cond_embedding_id)
shift = self.shift(cond_embedding_id)
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
x = x * scale + shift
return x
class ResBlock1(nn.Module):
"""
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
but without upsampling layers.
Args:
dim (int): Number of input channels.
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
Defaults to (1, 3, 5).
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
Defaults to 0.1.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
"""
def __init__(
self,
dim: int,
kernel_size: int = 3,
dilation: Tuple[int, int, int] = (1, 3, 5),
lrelu_slope: float = 0.1,
layer_scale_init_value: Optional[float] = None,
):
super().__init__()
self.lrelu_slope = lrelu_slope
self.convs1 = nn.ModuleList(
[
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[0],
padding=self.get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[1],
padding=self.get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
nn.Conv1d(
dim,
dim,
kernel_size,
1,
dilation=dilation[2],
padding=self.get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs2 = nn.ModuleList(
[
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
]
)
self.gamma = nn.ParameterList(
[
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
if layer_scale_init_value is not None
else None,
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
if layer_scale_init_value is not None
else None,
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
if layer_scale_init_value is not None
else None,
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
xt = c1(xt)
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
xt = c2(xt)
if gamma is not None:
xt = gamma * xt
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
@staticmethod
def get_padding(kernel_size: int, dilation: int = 1) -> int:
return int((kernel_size * dilation - dilation) / 2)
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
"""
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
Args:
x (Tensor): Input tensor.
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
Returns:
Tensor: Element-wise logarithm of the input tensor with clipping applied.
"""
return torch.log(torch.clip(x, min=clip_val))
def symlog(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(x.abs())
def symexp(x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * (torch.exp(x.abs()) - 1)

View File

@ -0,0 +1,230 @@
import numpy as np
import scipy
import torch
from torch import nn, view_as_real, view_as_complex
class ISTFT(nn.Module):
"""
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
See issue: https://github.com/pytorch/pytorch/issues/62323
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
The NOLA constraint is met as we trim padded samples anyway.
Args:
n_fft (int): Size of Fourier transform.
hop_length (int): The distance between neighboring sliding window frames.
win_length (int): The size of window frame and STFT filter.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(
self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
window = torch.hann_window(win_length)
self.register_buffer("window", window)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
Args:
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
N is the number of frequency bins, and T is the number of time frames.
Returns:
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
"""
if self.padding == "center":
# Fallback to pytorch native implementation
return torch.istft(
spec,
self.n_fft,
self.hop_length,
self.win_length,
self.window,
center=True,
)
elif self.padding == "same":
# return torch.istft(
# spec,
# self.n_fft,
# self.hop_length,
# self.win_length,
# self.window,
# center=False,
# )
pad = (self.win_length - self.hop_length) // 2
else:
raise ValueError("Padding must be 'center' or 'same'.")
assert spec.dim() == 3, "Expected a 3D tensor as input"
B, N, T = spec.shape
# Inverse FFT
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
ifft = ifft * self.window[None, :, None]
# Overlap and Add
output_size = (T - 1) * self.hop_length + self.win_length
y = torch.nn.functional.fold(
ifft,
output_size=(1, output_size),
kernel_size=(1, self.win_length),
stride=(1, self.hop_length),
)[:, 0, 0, :]
# Window envelope
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
window_envelope = torch.nn.functional.fold(
window_sq,
output_size=(1, output_size),
kernel_size=(1, self.win_length),
stride=(1, self.hop_length),
).squeeze()
# Normalize
norm_indexes = window_envelope > 1e-11
y[:, norm_indexes] = y[:, norm_indexes] / window_envelope[norm_indexes]
# assert (window_envelope > 1e-11).all()
# y = y / window_envelope
return y
class MDCT(nn.Module):
"""
Modified Discrete Cosine Transform (MDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, frame_len: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.frame_len = frame_len
N = frame_len // 2
n0 = (N + 1) / 2
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
self.register_buffer("window", window)
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
# view_as_real: NCCL Backend does not support ComplexFloat data type
# https://github.com/pytorch/pytorch/issues/71613
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
def forward(self, audio: torch.Tensor) -> torch.Tensor:
"""
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
Args:
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
and T is the length of the audio.
Returns:
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
and N is the number of frequency bins.
"""
if self.padding == "center":
audio = torch.nn.functional.pad(
audio, (self.frame_len // 2, self.frame_len // 2)
)
elif self.padding == "same":
# hop_length is 1/2 frame_len
audio = torch.nn.functional.pad(
audio, (self.frame_len // 4, self.frame_len // 4)
)
else:
raise ValueError("Padding must be 'center' or 'same'.")
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
N = self.frame_len // 2
x = x * self.window.expand(x.shape)
X = torch.fft.fft(
x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
)[..., :N]
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
return torch.real(res) * np.sqrt(2)
class IMDCT(nn.Module):
"""
Inverse Modified Discrete Cosine Transform (IMDCT) module.
Args:
frame_len (int): Length of the MDCT frame.
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
"""
def __init__(self, frame_len: int, padding: str = "same"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.frame_len = frame_len
N = frame_len // 2
n0 = (N + 1) / 2
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
self.register_buffer("window", window)
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
Args:
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
L is the number of frames, and N is the number of frequency bins.
Returns:
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
"""
B, L, N = X.shape
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
Y[..., :N] = X
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
y = torch.fft.ifft(
Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
)
y = (
torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
* np.sqrt(N)
* np.sqrt(2)
)
result = y * self.window.expand(y.shape)
output_size = (1, (L + 1) * N)
audio = torch.nn.functional.fold(
result.transpose(1, 2),
output_size=output_size,
kernel_size=(1, self.frame_len),
stride=(1, self.frame_len // 2),
)[:, 0, 0, :]
if self.padding == "center":
pad = self.frame_len // 2
elif self.padding == "same":
pad = self.frame_len // 4
else:
raise ValueError("Padding must be 'center' or 'same'.")
audio = audio[:, pad:-pad]
return audio

1003
egs/ljspeech/TTS/vocos/train.py Executable file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,372 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LJSpeechTtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--sampling-rate",
type=int,
default=22050,
help="The sampleing rate of ljspeech dataset",
)
group.add_argument(
"--frame-shift",
type=int,
default=256,
help="Frame shift.",
)
group.add_argument(
"--frame-length",
type=int,
default=1024,
help="Frame shift.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
group.add_argument(
"--use-fft-mag",
type=str2bool,
default=True,
help="Whether to use magnitude of fbank, false to use power energy.",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
sampling_rate = self.args.sampling_rate
config = FbankConfig(
sampling_rate=sampling_rate,
frame_length=self.args.frame_length / sampling_rate, # (in second),
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
use_fft_mag=self.args.use_fft_mag,
)
train = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = self.args.sampling_rate
config = FbankConfig(
sampling_rate=sampling_rate,
frame_length=self.args.frame_length / sampling_rate, # (in second),
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
use_fft_mag=self.args.use_fft_mag,
)
validate = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = self.args.sampling_rate
config = FbankConfig(
sampling_rate=sampling_rate,
frame_length=self.args.frame_length / sampling_rate, # (in second),
frame_shift=self.args.frame_shift / sampling_rate, # (in second)
use_fft_mag=self.args.use_fft_mag,
)
test = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
return_cuts=self.args.return_cuts,
)
else:
test = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
)
@lru_cache()
def train_cuts_finetune(self) -> CutSet:
logging.info("About to get train cuts finetune")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_train_finetune.jsonl.gz"
)
@lru_cache()
def valid_cuts_finetune(self) -> CutSet:
logging.info("About to get validation cuts finetune")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_valid_finetune.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
)

View File

@ -0,0 +1,205 @@
import glob
import os
import logging
import matplotlib
import math
import torch
import torch.nn as nn
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from torch.nn.utils import weight_norm
from torch.optim.lr_scheduler import LRScheduler
from torch.optim import Optimizer
from torch.cuda.amp import GradScaler
from lhotse.dataset.sampling.base import CutSampler
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
matplotlib.use("Agg")
import matplotlib.pylab as plt
def plot_spectrogram(spectrogram):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
fig.canvas.draw()
plt.close()
return fig
def load_checkpoint(
filename: Path,
model: nn.Module,
model_avg: Optional[nn.Module] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRScheduler] = None,
scheduler_d: Optional[LRScheduler] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
strict: bool = False,
) -> Dict[str, Any]:
logging.info(f"Loading checkpoint from {filename}")
checkpoint = torch.load(filename, map_location="cpu")
if next(iter(checkpoint["model"])).startswith("module."):
logging.info("Loading checkpoint saved by DDP")
dst_state_dict = model.state_dict()
src_state_dict = checkpoint["model"]
for key in dst_state_dict.keys():
src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict, strict=strict)
else:
model.load_state_dict(checkpoint["model"], strict=strict)
checkpoint.pop("model")
if model_avg is not None and "model_avg" in checkpoint:
logging.info("Loading averaged model")
model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
checkpoint.pop("model_avg")
def load(name, obj):
s = checkpoint.get(name, None)
if obj and s:
obj.load_state_dict(s)
checkpoint.pop(name)
load("optimizer_g", optimizer_g)
load("optimizer_d", optimizer_d)
load("scheduler_g", scheduler_g)
load("scheduler_d", scheduler_d)
load("grad_scaler", scaler)
load("sampler", sampler)
return checkpoint
def save_checkpoint(
filename: Path,
model: Union[nn.Module, DDP],
model_avg: Optional[nn.Module] = None,
params: Optional[Dict[str, Any]] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRScheduler] = None,
scheduler_d: Optional[LRScheduler] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
"""Save training information to a file.
Args:
filename:
The checkpoint filename.
model:
The model to be saved. We only save its `state_dict()`.
model_avg:
The stored model averaged from the start of training.
params:
User defined parameters, e.g., epoch, loss.
optimizer:
The optimizer to be saved. We only save its `state_dict()`.
scheduler:
The scheduler to be saved. We only save its `state_dict()`.
scalar:
The GradScaler to be saved. We only save its `state_dict()`.
rank:
Used in DDP. We save checkpoint only for the node whose rank is 0.
Returns:
Return None.
"""
if rank != 0:
return
logging.info(f"Saving checkpoint to {filename}")
if isinstance(model, DDP):
model = model.module
checkpoint = {
"model": model.state_dict(),
"optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
"optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
"scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
"scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
"grad_scaler": scaler.state_dict() if scaler is not None else None,
"sampler": sampler.state_dict() if sampler is not None else None,
}
if model_avg is not None:
checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
if params:
for k, v in params.items():
assert k not in checkpoint
checkpoint[k] = v
torch.save(checkpoint, filename)
def _get_cosine_schedule_with_warmup_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float,
min_lr_rate: float = 0.0,
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
factor = factor * (1 - min_lr_rate) + min_lr_rate
return max(0, factor)
def get_cosine_schedule_with_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_cosine_schedule_with_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)