From 604ab6f6b38d6be2fe6d9a265f9e9658e871d008 Mon Sep 17 00:00:00 2001 From: yuekaiz Date: Sun, 22 Dec 2024 15:22:26 +0800 Subject: [PATCH] add f5 --- egs/ljspeech/TTS/matcha/fbank.py | 5 +- .../TTS/f5-tts/model/__init__.py | 7 + egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py | 326 +++++ .../TTS/f5-tts/model/modules.py | 728 ++++++++++ egs/wenetspeech4tts/TTS/f5-tts/train.py | 1192 +++++++++++++++++ .../TTS/f5-tts/tts_datamodule.py | 346 +++++ egs/wenetspeech4tts/TTS/local/audio.py | 122 ++ .../TTS/local/compute_mel_feat.py | 218 +++ egs/wenetspeech4tts/TTS/local/fbank.py | 1 + egs/wenetspeech4tts/TTS/prepare.sh | 62 +- 10 files changed, 3002 insertions(+), 5 deletions(-) create mode 100644 egs/wenetspeech4tts/TTS/f5-tts/model/__init__.py create mode 100644 egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py create mode 100644 egs/wenetspeech4tts/TTS/f5-tts/model/modules.py create mode 100755 egs/wenetspeech4tts/TTS/f5-tts/train.py create mode 100644 egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py create mode 100644 egs/wenetspeech4tts/TTS/local/audio.py create mode 100755 egs/wenetspeech4tts/TTS/local/compute_mel_feat.py create mode 120000 egs/wenetspeech4tts/TTS/local/fbank.py diff --git a/egs/ljspeech/TTS/matcha/fbank.py b/egs/ljspeech/TTS/matcha/fbank.py index d729fa425..cc94a301f 100644 --- a/egs/ljspeech/TTS/matcha/fbank.py +++ b/egs/ljspeech/TTS/matcha/fbank.py @@ -17,6 +17,7 @@ class MatchaFbankConfig: win_length: int f_min: float f_max: float + device: str = "cuda" @register_extractor @@ -46,7 +47,7 @@ class MatchaFbank(FeatureExtractor): f"Mismatched sampling rate: extractor expects {expected_sr}, " f"got {sampling_rate}" ) - samples = torch.from_numpy(samples) + samples = torch.from_numpy(samples).to(self.device) assert samples.ndim == 2, samples.shape assert samples.shape[0] == 1, samples.shape @@ -81,7 +82,7 @@ class MatchaFbank(FeatureExtractor): mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" ).squeeze(0) - return mel.numpy() + return mel.cpu().numpy() @property def frame_shift(self) -> Seconds: diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/__init__.py b/egs/wenetspeech4tts/TTS/f5-tts/model/__init__.py new file mode 100644 index 000000000..878b5c171 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/__init__.py @@ -0,0 +1,7 @@ +# from f5_tts.model.backbones.dit import DiT +# from f5_tts.model.backbones.mmdit import MMDiT +# from f5_tts.model.backbones.unett import UNetT +# from f5_tts.model.cfm import CFM +# from f5_tts.model.trainer import Trainer + +# __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"] diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py b/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py new file mode 100644 index 000000000..349c7220e --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py @@ -0,0 +1,326 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +from random import random +from typing import Callable + +import torch +import torch.nn.functional as F +from model.modules import MelSpec +from model.utils import ( + default, + exists, + lens_to_mask, + list_str_to_idx, + list_str_to_tensor, + mask_from_frac_lengths, +) +from torch import nn +from torch.nn.utils.rnn import pad_sequence +from torchdiffeq import odeint + + +class CFM(nn.Module): + def __init__( + self, + transformer: nn.Module, + sigma=0.0, + odeint_kwargs: dict = dict( + # atol = 1e-5, + # rtol = 1e-5, + method="euler" # 'midpoint' + ), + audio_drop_prob=0.3, + 
cond_drop_prob=0.2, + num_channels=None, + mel_spec_module: nn.Module | None = None, + mel_spec_kwargs: dict = dict(), + frac_lengths_mask: tuple[float, float] = (0.7, 1.0), + vocab_char_map: dict[str:int] | None = None, + ): + super().__init__() + + self.frac_lengths_mask = frac_lengths_mask + + # mel spec + self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) + num_channels = default(num_channels, self.mel_spec.n_mel_channels) + self.num_channels = num_channels + + # classifier-free guidance + self.audio_drop_prob = audio_drop_prob + self.cond_drop_prob = cond_drop_prob + + # transformer + self.transformer = transformer + dim = transformer.dim + self.dim = dim + + # conditional flow related + self.sigma = sigma + + # sampling related + self.odeint_kwargs = odeint_kwargs + + # vocab map for tokenization + self.vocab_char_map = vocab_char_map + + @property + def device(self): + return next(self.parameters()).device + + @torch.no_grad() + def sample( + self, + cond: float["b n d"] | float["b nw"], # noqa: F722 + text: int["b nt"] | list[str], # noqa: F722 + duration: int | int["b"], # noqa: F821 + *, + lens: int["b"] | None = None, # noqa: F821 + steps=32, + cfg_strength=1.0, + sway_sampling_coef=None, + seed: int | None = None, + max_duration=4096, + vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 + no_ref_audio=False, + duplicate_test=False, + t_inter=0.1, + edit_mask=None, + ): + self.eval() + # raw wave + + if cond.ndim == 2: + cond = self.mel_spec(cond) + cond = cond.permute(0, 2, 1) + assert cond.shape[-1] == self.num_channels + + cond = cond.to(next(self.parameters()).dtype) + + batch, cond_seq_len, device = *cond.shape[:2], cond.device + if not exists(lens): + lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long) + + # text + + if isinstance(text, list): + if exists(self.vocab_char_map): + text = list_str_to_idx(text, self.vocab_char_map).to(device) + else: + text = list_str_to_tensor(text).to(device) + assert text.shape[0] == batch + + if exists(text): + text_lens = (text != -1).sum(dim=-1) + lens = torch.maximum( + text_lens, lens + ) # make sure lengths are at least those of the text characters + + # duration + + cond_mask = lens_to_mask(lens) + if edit_mask is not None: + cond_mask = cond_mask & edit_mask + + if isinstance(duration, int): + duration = torch.full((batch,), duration, device=device, dtype=torch.long) + + duration = torch.maximum( + lens + 1, duration + ) # just add one token so something is generated + duration = duration.clamp(max=max_duration) + max_duration = duration.amax() + + # duplicate test corner for inner time step oberservation + if duplicate_test: + test_cond = F.pad( + cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0 + ) + + cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0) + cond_mask = F.pad( + cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False + ) + cond_mask = cond_mask.unsqueeze(-1) + step_cond = torch.where( + cond_mask, cond, torch.zeros_like(cond) + ) # allow direct control (cut cond audio) with lens passed in + + if batch > 1: + mask = lens_to_mask(duration) + else: # save memory and speed up, as single inference need no mask currently + mask = None + + # test for no ref audio + if no_ref_audio: + cond = torch.zeros_like(cond) + + # neural ode + + def fn(t, x): + # at each step, conditioning is fixed + # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) + + # predict flow + pred = self.transformer( + x=x, + 
cond=step_cond, + text=text, + time=t, + mask=mask, + drop_audio_cond=False, + drop_text=False, + ) + if cfg_strength < 1e-5: + return pred + + null_pred = self.transformer( + x=x, + cond=step_cond, + text=text, + time=t, + mask=mask, + drop_audio_cond=True, + drop_text=True, + ) + return pred + (pred - null_pred) * cfg_strength + + # noise input + # to make sure batch inference result is same with different batch size, and for sure single inference + # still some difference maybe due to convolutional layers + y0 = [] + for dur in duration: + if exists(seed): + torch.manual_seed(seed) + y0.append( + torch.randn( + dur, self.num_channels, device=self.device, dtype=step_cond.dtype + ) + ) + y0 = pad_sequence(y0, padding_value=0, batch_first=True) + + t_start = 0 + + # duplicate test corner for inner time step oberservation + if duplicate_test: + t_start = t_inter + y0 = (1 - t_start) * y0 + t_start * test_cond + steps = int(steps * (1 - t_start)) + + t = torch.linspace( + t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype + ) + if sway_sampling_coef is not None: + t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) + + trajectory = odeint(fn, y0, t, **self.odeint_kwargs) + + sampled = trajectory[-1] + out = sampled + out = torch.where(cond_mask, cond, out) + + if exists(vocoder): + out = out.permute(0, 2, 1) + out = vocoder(out) + + return out, trajectory + + def forward( + self, + inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722 + text: int["b nt"] | list[str], # noqa: F722 + *, + lens: int["b"] | None = None, # noqa: F821 + noise_scheduler: str | None = None, + ): + # handle raw wave + if inp.ndim == 2: + inp = self.mel_spec(inp) + inp = inp.permute(0, 2, 1) + assert inp.shape[-1] == self.num_channels + + batch, seq_len, dtype, device, _σ1 = ( + *inp.shape[:2], + inp.dtype, + self.device, + self.sigma, + ) + + # handle text as string + if isinstance(text, list): + if exists(self.vocab_char_map): + text = list_str_to_idx(text, self.vocab_char_map).to(device) + else: + text = list_str_to_tensor(text).to(device) + assert text.shape[0] == batch + + # lens and mask + if not exists(lens): + lens = torch.full((batch,), seq_len, device=device) + + mask = lens_to_mask( + lens, length=seq_len + ) # useless here, as collate_fn will pad to max length in batch + + # get a random span to mask out for training conditionally + frac_lengths = ( + torch.zeros((batch,), device=self.device) + .float() + .uniform_(*self.frac_lengths_mask) + ) + rand_span_mask = mask_from_frac_lengths(lens, frac_lengths) + + if exists(mask): + rand_span_mask &= mask + + # mel is x1 + x1 = inp + + # x0 is gaussian noise + x0 = torch.randn_like(x1) + + # time step + time = torch.rand((batch,), dtype=dtype, device=self.device) + # TODO. 
noise_scheduler + + # sample xt (φ_t(x) in the paper) + t = time.unsqueeze(-1).unsqueeze(-1) + φ = (1 - t) * x0 + t * x1 + flow = x1 - x0 + + # only predict what is within the random mask span for infilling + cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1) + + # transformer and cfg training with a drop rate + drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper + if random() < self.cond_drop_prob: # p_uncond in voicebox paper + drop_audio_cond = True + drop_text = True + else: + drop_text = False + + # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here + # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences + pred = self.transformer( + x=φ, + cond=cond, + text=text, + time=time, + drop_audio_cond=drop_audio_cond, + drop_text=drop_text, + ) + + # flow matching loss + loss = F.mse_loss(pred, flow, reduction="none") + loss = loss[rand_span_mask] + + return loss.mean(), cond, pred diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/modules.py b/egs/wenetspeech4tts/TTS/f5-tts/model/modules.py new file mode 100644 index 000000000..05299d419 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/modules.py @@ -0,0 +1,728 @@ +""" +ein notation: +b - batch +n - sequence +nt - text sequence +nw - raw wave length +d - dimension +""" + +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn.functional as F +import torchaudio +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from x_transformers.x_transformers import apply_rotary_pos_emb + +# raw wav to mel spec + + +mel_basis_cache = {} +hann_window_cache = {} + + +def get_bigvgan_mel_spectrogram( + waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, + fmin=0, + fmax=None, + center=False, +): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main + device = waveform.device + key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}" + + if key not in mel_basis_cache: + mel = librosa_mel_fn( + sr=target_sample_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=fmin, + fmax=fmax, + ) + mel_basis_cache[key] = ( + torch.from_numpy(mel).float().to(device) + ) # TODO: why they need .float()? 
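+        # likely answer to the TODO above: librosa may return a float64 filterbank
+        # depending on its dtype defaults, so .float() pins the basis to float32 and
+        # keeps the matmul with the float32 STFT output dtype-consistent; it is a
+        # no-op when the filterbank is already float32.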
+ hann_window_cache[key] = torch.hann_window(win_length).to(device) + + mel_basis = mel_basis_cache[key] + hann_window = hann_window_cache[key] + + padding = (n_fft - hop_length) // 2 + waveform = torch.nn.functional.pad( + waveform.unsqueeze(1), (padding, padding), mode="reflect" + ).squeeze(1) + + spec = torch.stft( + waveform, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5)) + + return mel_spec + + +def get_vocos_mel_spectrogram( + waveform, + n_fft=1024, + n_mel_channels=100, + target_sample_rate=24000, + hop_length=256, + win_length=1024, +): + mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=target_sample_rate, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + n_mels=n_mel_channels, + power=1, + center=True, + normalized=False, + norm=None, + ).to(waveform.device) + if len(waveform.shape) == 3: + waveform = waveform.squeeze(1) # 'b 1 nw -> b nw' + + assert len(waveform.shape) == 2 + + mel = mel_stft(waveform) + mel = mel.clamp(min=1e-5).log() + return mel + + +class MelSpec(nn.Module): + def __init__( + self, + n_fft=1024, + hop_length=256, + win_length=1024, + n_mel_channels=100, + target_sample_rate=24_000, + mel_spec_type="vocos", + ): + super().__init__() + assert mel_spec_type in ["vocos", "bigvgan"], print( + "We only support two extract mel backend: vocos or bigvgan" + ) + + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.n_mel_channels = n_mel_channels + self.target_sample_rate = target_sample_rate + + if mel_spec_type == "vocos": + self.extractor = get_vocos_mel_spectrogram + elif mel_spec_type == "bigvgan": + self.extractor = get_bigvgan_mel_spectrogram + + self.register_buffer("dummy", torch.tensor(0), persistent=False) + + def forward(self, wav): + if self.dummy.device != wav.device: + self.to(wav.device) + + mel = self.extractor( + waveform=wav, + n_fft=self.n_fft, + n_mel_channels=self.n_mel_channels, + target_sample_rate=self.target_sample_rate, + hop_length=self.hop_length, + win_length=self.win_length, + ) + + return mel + + +# sinusoidal position embedding + + +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +# convolutional position embedding + + +class ConvPositionEmbedding(nn.Module): + def __init__(self, dim, kernel_size=31, groups=16): + super().__init__() + assert kernel_size % 2 != 0 + self.conv1d = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2), + nn.Mish(), + ) + + def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722 + if mask is not None: + mask = mask[..., None] + x = x.masked_fill(~mask, 0.0) + + x = x.permute(0, 2, 1) + x = self.conv1d(x) + out = x.permute(0, 2, 1) + + if mask is not None: + out = out.masked_fill(~mask, 0.0) + + return out + + +# rotary 
positional embedding related + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0 +): + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py + theta *= theta_rescale_factor ** (dim / (dim - 2)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cos = torch.cos(freqs) # real part + freqs_sin = torch.sin(freqs) # imaginary part + return torch.cat([freqs_cos, freqs_sin], dim=-1) + + +def get_pos_embed_indices(start, length, max_pos, scale=1.0): + # length = length if isinstance(length, int) else length.max() + scale = scale * torch.ones_like( + start, dtype=torch.float32 + ) # in case scale is a scalar + pos = ( + start.unsqueeze(1) + + ( + torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) + * scale.unsqueeze(1) + ).long() + ) + # avoid extra long error. + pos = torch.where(pos < max_pos, pos, max_pos - 1) + return pos + + +# Global Response Normalization layer (Instance Normalization ?) + + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=1, keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py +# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108 + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = x.transpose(1, 2) # b n d -> b d n + x = self.dwconv(x) + x = x.transpose(1, 2) # b d n -> b n d + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + return residual + x + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation + + +class AdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk( + emb, 6, dim=1 + ) + + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no 
more mlp modulation + + +class AdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +# FeedForward + + +class FeedForward(nn.Module): + def __init__( + self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none" + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + activation = nn.GELU(approximate=approximate) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation) + self.ff = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.ff(x) + + +# Attention with possible joint part +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class Attention(nn.Module): + def __init__( + self, + processor: JointAttnProcessor | AttnProcessor, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + context_dim: Optional[int] = None, # if not None -> joint attention + context_pre_only=None, + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + self.processor = processor + + self.dim = dim + self.heads = heads + self.inner_dim = dim_head * heads + self.dropout = dropout + + self.context_dim = context_dim + self.context_pre_only = context_pre_only + + self.to_q = nn.Linear(dim, self.inner_dim) + self.to_k = nn.Linear(dim, self.inner_dim) + self.to_v = nn.Linear(dim, self.inner_dim) + + if self.context_dim is not None: + self.to_k_c = nn.Linear(context_dim, self.inner_dim) + self.to_v_c = nn.Linear(context_dim, self.inner_dim) + if self.context_pre_only is not None: + self.to_q_c = nn.Linear(context_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, dim)) + self.to_out.append(nn.Dropout(dropout)) + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_out_c = nn.Linear(self.inner_dim, dim) + + def forward( + self, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b n d"] = None, # context c # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.Tensor: + if c is not None: + return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope) + else: + return self.processor(self, x, mask=mask, rope=rope) + + +# Attention processor + + +class AttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding + ) -> torch.FloatTensor: + batch_size = x.shape[0] + + # `sample` projections. 
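+        # note: this is the plain self-attention path used by DiTBlock; the joint
+        # text/audio streams of MM-DiT are handled by JointAttnProcessor below.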
+ query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # apply rotary position embedding + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) + if xpos_scale is not None + else (1.0, 1.0) + ) + + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = mask + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) + else: + attn_mask = None + + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + + return x + + +# Joint Attention processor for MM-DiT +# modified from diffusers/src/diffusers/models/attention_processor.py + + +class JointAttnProcessor: + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + x: float["b n d"], # noised input x # noqa: F722 + c: float["b nt d"] = None, # context c, here text # noqa: F722 + mask: bool["b n"] | None = None, # noqa: F722 + rope=None, # rotary position embedding for x + c_rope=None, # rotary position embedding for c + ) -> torch.FloatTensor: + residual = x + + batch_size = c.shape[0] + + # `sample` projections. + query = attn.to_q(x) + key = attn.to_k(x) + value = attn.to_v(x) + + # `context` projections. + c_query = attn.to_q_c(c) + c_key = attn.to_k_c(c) + c_value = attn.to_v_c(c) + + # apply rope for context and noised input independently + if rope is not None: + freqs, xpos_scale = rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) + if xpos_scale is not None + else (1.0, 1.0) + ) + query = apply_rotary_pos_emb(query, freqs, q_xpos_scale) + key = apply_rotary_pos_emb(key, freqs, k_xpos_scale) + if c_rope is not None: + freqs, xpos_scale = c_rope + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) + if xpos_scale is not None + else (1.0, 1.0) + ) + c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale) + c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale) + + # attention + query = torch.cat([query, c_query], dim=1) + key = torch.cat([key, c_key], dim=1) + value = torch.cat([value, c_value], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # mask. e.g. 
inference got a batch with different target durations, mask out the padding + if mask is not None: + attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text) + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n' + attn_mask = attn_mask.expand( + batch_size, attn.heads, query.shape[-2], key.shape[-2] + ) + else: + attn_mask = None + + x = F.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False + ) + x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + x = x.to(query.dtype) + + # Split the attention outputs. + x, c = ( + x[:, : residual.shape[1]], + x[:, residual.shape[1] :], + ) + + # linear proj + x = attn.to_out[0](x) + # dropout + x = attn.to_out[1](x) + if not attn.context_pre_only: + c = attn.to_out_c(c) + + if mask is not None: + mask = mask.unsqueeze(-1) + x = x.masked_fill(~mask, 0.0) + # c = c.masked_fill(~mask, 0.) # no mask for c (text) + + return x, c + + +# DiT Block + + +class DiTBlock(nn.Module): + def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1): + super().__init__() + + self.attn_norm = AdaLayerNormZero(dim) + self.attn = Attention( + processor=AttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + ) + + self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) + + def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) + + # attention + attn_output = self.attn(x=norm, mask=mask, rope=rope) + + # process attention output for input x + x = x + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + x = x + gate_mlp.unsqueeze(1) * ff_output + + return x + + +# MMDiT Block https://arxiv.org/abs/2403.03206 + + +class MMDiTBlock(nn.Module): + r""" + modified from diffusers/src/diffusers/models/attention.py + + notes. + _c: context related. text, cond, etc. (left part in sd3 fig2.b) + _x: noised input related. 
(right part) + context_pre_only: last layer only do prenorm + modulation cuz no more ffn + """ + + def __init__( + self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False + ): + super().__init__() + + self.context_pre_only = context_pre_only + + self.attn_norm_c = ( + AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim) + ) + self.attn_norm_x = AdaLayerNormZero(dim) + self.attn = Attention( + processor=JointAttnProcessor(), + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + context_dim=dim, + context_pre_only=context_pre_only, + ) + + if not context_pre_only: + self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_c = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) + else: + self.ff_norm_c = None + self.ff_c = None + self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_x = FeedForward( + dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh" + ) + + def forward( + self, x, c, t, mask=None, rope=None, c_rope=None + ): # x: noised input, c: context, t: time embedding + # pre-norm & modulation for attention input + if self.context_pre_only: + norm_c = self.attn_norm_c(c, t) + else: + norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c( + c, emb=t + ) + norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x( + x, emb=t + ) + + # attention + x_attn_output, c_attn_output = self.attn( + x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope + ) + + # process attention output for context c + if self.context_pre_only: + c = None + else: # if not last layer + c = c + c_gate_msa.unsqueeze(1) * c_attn_output + + norm_c = ( + self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + ) + c_ff_output = self.ff_c(norm_c) + c = c + c_gate_mlp.unsqueeze(1) * c_ff_output + + # process attention output for input x + x = x + x_gate_msa.unsqueeze(1) * x_attn_output + + norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None] + x_ff_output = self.ff_x(norm_x) + x = x + x_gate_mlp.unsqueeze(1) * x_ff_output + + return c, x + + +# time step conditioning embedding + + +class TimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential( + nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim) + ) + + def forward(self, timestep: float["b"]): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + time = self.time_mlp(time_hidden) # b d + return time diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py new file mode 100755 index 000000000..354b421b6 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -0,0 +1,1192 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo) +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (authors: Yuekai Zhang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +world_size=8 +exp_dir=exp/ft-tts +""" + +import argparse +import copy +import logging +import os +import random +import warnings +from contextlib import nullcontext +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model.cfm import CFM +from model.dit import DiT +from model.utils import MelSpec +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import TtsDataModule + +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--decoder-dim", + type=int, + default=1024, + help="Embedding dimension in the decoder model.", + ) + parser.add_argument( + "--nhead", + type=int, + default=16, + help="Number of attention heads in the Decoder layers.", + ) + parser.add_argument( + "--num-decoder-layers", + type=int, + default=22, + help="Number of Decoder layers.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. 
+ If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="exp/valle_dev", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="ft-tts/vocab.txt", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--pretrained-model-path", + type=str, + default="/home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--optimizer-name", + type=str, + default="ScaledAdam", + help="The optimizer.", + ) + parser.add_argument( + "--scheduler-name", + type=str, + default="Eden", + help="The scheduler.", + ) + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + parser.add_argument( + "--warmup-steps", + type=int, + default=200, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train %% save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + parser.add_argument( + "--valid-interval", + type=int, + default=10000, + help="""Run validation if batch_idx %% valid_interval is 0.""", + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=0, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accumulate-grad-steps", + type=int, + default=1, + help="""update gradient when batch_idx_train %% accumulate_grad_steps == 0. 
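+        Increasing it emulates a larger effective batch size
+        (batch_size * accumulate_grad_steps) when GPU memory is limited.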
+ """, + ) + + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + help="Training dtype: float32 bfloat16 float16.", + ) + + parser.add_argument( + "--filter-min-duration", + type=float, + default=0.0, + help="Keep only utterances with duration > this.", + ) + parser.add_argument( + "--filter-max-duration", + type=float, + default=20.0, + help="Keep only utterances with duration < this.", + ) + + parser.add_argument( + "--visualize", + type=str2bool, + default=False, + help="visualize model results in eval step.", + ) + + parser.add_argument( + "--oom-check", + type=str2bool, + default=False, + help="perform OOM check on dataloader batches before starting training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 1, + "reset_interval": 200, + "valid_interval": 10000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_tokenizer(vocab_file_path: str): + """ + tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file + - "char" for char-wise tokenizer, need .txt vocab_file + - "byte" for utf-8 tokenizer + - "custom" if you're directly passing in a path to the vocab.txt you want to use + vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols + - if use "char", derived from unfiltered character & symbol counts of custom dataset + - if use "byte", set to 256 (unicode byte range) + """ + with open(vocab_file_path, "r", encoding="utf-8") as f: + vocab_char_map = {} + for i, char in enumerate(f): + vocab_char_map[char[:-1]] = i + vocab_size = len(vocab_char_map) + + return vocab_char_map, vocab_size + + +def get_model(params): + vocab_char_map, vocab_size = get_tokenizer(params.tokens) + n_mel_channels = 100 + n_fft = 1024 + sampling_rate = 24_000 + hop_length = 256 + win_length = 1024 + + model_cfg = { + "dim": params.decoder_dim, + "depth": params.num_decoder_layers, + "heads": params.nhead, + "ff_mult": 2, + "text_dim": 512, + "conv_layers": 4, + "checkpoint_activations": False, + } + model = CFM( + transformer=DiT( + **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels + ), + mel_spec_kwargs=dict( + n_fft=n_fft, 
+ hop_length=hop_length, + win_length=win_length, + n_mel_channels=n_mel_channels, + target_sample_rate=sampling_rate, + mel_spec_type="bigvgan", + ), + odeint_kwargs=dict( + method="euler", + ), + vocab_char_map=vocab_char_map, + ) + return model + + +def load_pretrained_checkpoint( + model, ckpt_path, device: str = "cpu", dtype=torch.float32 +): + model = model.to(dtype) + checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) + + checkpoint["model_state_dict"] = { + k.replace("ema_model.", ""): v + for k, v in checkpoint["ema_model_state_dict"].items() + if k not in ["initted", "step"] + } + + # patch for backward compatibility, 305e3ea + for key in [ + "mel_spec.mel_stft.mel_scale.fb", + "mel_spec.mel_stft.spectrogram.window", + ]: + if key in checkpoint["model_state_dict"]: + del checkpoint["model_state_dict"][key] + model.load_state_dict(checkpoint["model_state_dict"]) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + if isinstance(model, DDP): + raise ValueError("load_checkpoint before DDP") + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
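+      rank:
+        The rank of the node in DDP training; only rank 0 actually writes
+        the checkpoint to disk.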
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def prepare_input(batch: dict, tokenizer, device: torch.device): + """Parse batch data""" + print(batch.keys()) + print(batch) + text_inputs = batch["text"] + mel_spec = batch["mel"].permute(0, 2, 1) + mel_lengths = batch["mel_lengths"] + return text_inputs, mel_spec, mel_lengths + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + (mel_spec, text_inputs, mel_lengths) = prepare_input(batch, device) + # at entry, TextTokens is (N, P) + assert text_inputs.ndim == 2 + assert mel_spec.ndim == 3 + + with torch.set_grad_enabled(is_training): + loss, cond, pred = model(mel_spec, text=text_inputs, lens=mel_lengths) + assert loss.requires_grad == is_training + + info = MetricsTracker() + exit(0) + # with warnings.catch_warnings(): + # warnings.simplefilter("ignore") + # info["frames"] = (audio_features_lens).sum().item() + # info["utterances"] = text_tokens.size(0) + + # # # Note: We use reduction=sum while computing the loss. + # # info["loss"] = loss.detach().cpu().item() + # # for metric in metrics: + # # info[metric] = metrics[metric].detach().cpu().item() + # # del metrics + # # Note: We use reduction=sum while computing the loss. 
+ # info["loss"] = loss.detach().cpu().item() * info["frames"] + + # for i in range(len(loss_list)): + # info[f"loss_{i}"] = loss_list[i].detach().cpu().item() * info["frames"] + # for i in range(len(acc_list)): + # info[f"acc_{i}"] = acc_list[i] * info["frames"] + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + if world_size > 1: + tot_loss.reduce(loss.device) + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + # if params.visualize: + # output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}") + # output_dir.mkdir(parents=True, exist_ok=True) + # if isinstance(model, DDP): + # model.module.visualize(predicts, batch, output_dir=output_dir) + # else: + # model.visualize(predicts, batch, output_dir=output_dir) + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer, + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + rng: random.Random, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + rng: + Random for selecting. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
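+      tokenizer:
+        The tokenizer returned by :func:`get_tokenizer`; it is forwarded to
+        :func:`compute_loss` for each batch.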
+ """ + model.train() + tot_loss = MetricsTracker() + iter_dl = iter(train_dl) + + dtype, enabled = torch.float32, False + if params.dtype in ["bfloat16", "bf16"]: + dtype, enabled = torch.bfloat16, True + elif params.dtype in ["float16", "fp16"]: + dtype, enabled = torch.float16, True + + batch_idx = 0 + while True: + try: + batch = next(iter_dl) + except StopIteration: + logging.info("Reaches end of dataloader.") + break + + batch_idx += 1 + + params.batch_idx_train += 1 + batch_size = len(batch["text"]) + + try: + with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled): + loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * ( + 1 / params.reset_interval + ) + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + scaler.scale(loss).backward() + if params.batch_idx_train >= params.accumulate_grad_steps: + if params.batch_idx_train % params.accumulate_grad_steps == 0: + if params.optimizer_name not in ["ScaledAdam", "Eve"]: + # Unscales the gradients of optimizer's assigned params in-place + scaler.unscale_(optimizer) + # Since the gradients of optimizer's assigned params are unscaled, clips as usual: + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + for k in range(params.accumulate_grad_steps): + if isinstance(scheduler, Eden): + scheduler.step_batch(params.batch_idx_train) + else: + scheduler.step() + + set_batch_count(model, params.batch_idx_train) + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if params.average_period > 0: + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + # Perform Operation in rank 0 + if rank == 0: + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + # Perform Operation in rank 0 + if rank == 0: + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = ( + scaler._scale.item() if params.dtype in ["float16", "fp16"] else 1.0 + ) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, train_loss[{loss_info}], " + f"batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + + ( + f", grad_scale: {cur_grad_scale}" + if params.dtype in ["float16", "fp16"] + else "" + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, + "train/current_", + params.batch_idx_train, + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.dtype in ["float16", "fp16"]: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if params.batch_idx_train % params.valid_interval == 0: + # Calculate validation loss in Rank 0 + model.eval() + logging.info("Computing validation loss") + with torch.cuda.amp.autocast(dtype=dtype): + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + model.train() + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def filter_short_and_long_utterances( + cuts: CutSet, min_duration: float, max_duration: float +) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 0.6 second and 20 seconds + if c.duration < min_duration or c.duration > max_duration: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + cuts = cuts.filter(remove_short_and_long_utt) + + return cuts + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + rng = random.Random(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + if params.train_stage: + tb_writer = SummaryWriter( + log_dir=f"{params.exp_dir}/tensorboard_stage{params.train_stage}" + ) + else: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + logging.info(f"Device: {device}") + tokenizer = get_tokenizer(params.tokens) + print("the class type of tokenizer is: ", type(tokenizer)) + logging.info(params) + + logging.info("About to create model") + + model = get_model(params) + model = load_pretrained_checkpoint(model, params.pretrained_model_path) + + model = model.to(device) + + with open(f"{params.exp_dir}/model.txt", "w") as f: + print(model) + print(model, file=f) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0 and params.average_period > 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + model_parameters = model.parameters() + + if params.optimizer_name == "ScaledAdam": + optimizer = ScaledAdam( + model_parameters, + lr=params.base_lr, + clipping_scale=2.0, + ) + elif params.optimizer_name == "AdamW": + optimizer = torch.optim.AdamW( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + weight_decay=1e-2, + eps=1e-8, + ) + elif params.optimizer_name == "Adam": + optimizer = torch.optim.Adam( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + eps=1e-8, + ) + else: + raise NotImplementedError() + + scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps) + + optimizer.zero_grad() + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.inf_check: + register_inf_check_hooks(model) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + dataset = TtsDataModule(args) + train_cuts = dataset.train_cuts() + valid_cuts = dataset.dev_cuts() + + train_cuts = filter_short_and_long_utterances( + train_cuts, params.filter_min_duration, params.filter_max_duration + ) + valid_cuts = filter_short_and_long_utterances( + valid_cuts, params.filter_min_duration, params.filter_max_duration + ) + + train_dl 
= dataset.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + valid_dl = dataset.dev_dataloaders(valid_cuts) + + if params.oom_check: + scan_pessimistic_batches_for_oom( + model=model, + tokenizer=tokenizer, + train_dl=train_dl, + optimizer=optimizer, + params=params, + ) + + scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + if isinstance(scheduler, Eden): + scheduler.step_epoch(epoch - 1) + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + rng=rng, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + tokenizer, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + + dtype = torch.float32 + if params.dtype in ["bfloat16", "bf16"]: + dtype = torch.bfloat16 + elif params.dtype in ["float16", "fp16"]: + dtype = torch.float16 + + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(dtype=dtype): + _, loss, _ = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + TtsDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py new file mode 100644 index 000000000..7a665a54c --- /dev/null +++ b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py @@ -0,0 +1,346 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from fbank import MatchaFbank, MatchaFbankConfig +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class TtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
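+
+    A minimal usage sketch (hypothetical invocation; it assumes the cut
+    manifests produced by prepare.sh already exist under --manifest-dir):
+
+        parser = argparse.ArgumentParser()
+        TtsDataModule.add_arguments(parser)
+        args = parser.parse_args(["--manifest-dir", "data/fbank"])
+
+        dm = TtsDataModule(args)
+        train_dl = dm.train_dataloaders(dm.train_cuts())
+        valid_dl = dm.valid_dataloaders(dm.valid_cuts())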
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + parser.add_argument( + "--prefix", + type=str, + default="wenetspeech4tts", + help="prefix of the manifest file", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
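+
+        Example (a sketch only; "args" and "checkpoints" are hypothetical
+        stand-ins for the parsed arguments and for whatever
+        load_checkpoint_if_available() returned in train.py):
+
+            dm = TtsDataModule(args)
+            train_dl = dm.train_dataloaders(
+                dm.train_cuts(),
+                sampler_state_dict=checkpoints.get("sampler"),
+            )
+            for batch in train_dl:
+                # batch layout is defined by lhotse's SpeechSynthesisDataset
+                break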
+ """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + pin_memory=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=True, + pin_memory=True, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = 
DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"{self.args.prefix}_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"{self.args.prefix}_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / f"{self.args.prefix}_cuts_test.jsonl.gz" + ) diff --git a/egs/wenetspeech4tts/TTS/local/audio.py b/egs/wenetspeech4tts/TTS/local/audio.py new file mode 100644 index 000000000..b643e3de0 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/audio.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024 NVIDIA CORPORATION. +# Licensed under the MIT license. + +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# LICENSE is in incl_licenses directory. + +import math +import os +import pathlib +import random +from typing import List, Optional, Tuple + +import librosa +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from tqdm import tqdm + +# from env import AttrDict + +MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases) + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + return dynamic_range_compression_torch(magnitudes) + + +def spectral_de_normalize_torch(magnitudes): + return dynamic_range_decompression_torch(magnitudes) + + +mel_basis_cache = {} +hann_window_cache = {} + + +def mel_spectrogram( + y: torch.Tensor, + n_fft: int = 1024, + num_mels: int = 100, + sampling_rate: int = 24_000, + hop_size: int = 256, + win_size: int = 1024, + fmin: int = 0, + fmax: int = None, + center: bool = False, +) -> torch.Tensor: + """ + Calculate the mel spectrogram of an input signal. + This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft). + + Args: + y (torch.Tensor): Input signal. + n_fft (int): FFT size. + num_mels (int): Number of mel bins. + sampling_rate (int): Sampling rate of the input signal. + hop_size (int): Hop size for STFT. + win_size (int): Window size for STFT. + fmin (int): Minimum frequency for mel filterbank. + fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn + center (bool): Whether to pad the input to center the frames. Default is False. + + Returns: + torch.Tensor: Mel spectrogram. 
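+
+    Example (illustrative sketch with random samples; the settings mirror
+    the "bigvgan" configuration used in local/compute_mel_feat.py):
+
+        wav = torch.randn(1, 24_000).clamp(-1.0, 1.0)  # (batch, num_samples)
+        mel = mel_spectrogram(
+            wav,
+            n_fft=1024,
+            num_mels=100,
+            sampling_rate=24_000,
+            hop_size=256,
+            win_size=1024,
+            fmin=0,
+            fmax=None,
+        )
+        # mel: (batch, num_mels, num_frames); one second of 24 kHz audio
+        # gives torch.Size([1, 100, 93]) with these settings.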
+ """ + if torch.min(y) < -1.0: + print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}") + if torch.max(y) > 1.0: + print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}") + + device = y.device + key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}" + + if key not in mel_basis_cache: + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) + hann_window_cache[key] = torch.hann_window(win_size).to(device) + + mel_basis = mel_basis_cache[key] + hann_window = hann_window_cache[key] + + padding = (n_fft - hop_size) // 2 + y = torch.nn.functional.pad( + y.unsqueeze(1), (padding, padding), mode="reflect" + ).squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) + + mel_spec = torch.matmul(mel_basis, spec) + mel_spec = spectral_normalize_torch(mel_spec) + + return mel_spec diff --git a/egs/wenetspeech4tts/TTS/local/compute_mel_feat.py b/egs/wenetspeech4tts/TTS/local/compute_mel_feat.py new file mode 100755 index 000000000..5292c75ad --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/compute_mel_feat.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LJSpeech dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from fbank import MatchaFbank, MatchaFbankConfig +from lhotse import CutSet, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--num-jobs", + type=int, + default=1, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ """, + ) + + parser.add_argument( + "--src-dir", + type=Path, + default=Path("data/manifests"), + help="Path to the manifest files", + ) + + parser.add_argument( + "--output-dir", + type=Path, + default=Path("data/fbank"), + help="Path to the tokenized files", + ) + + parser.add_argument( + "--dataset-parts", + type=str, + default="Basic", + help="Space separated dataset parts", + ) + + parser.add_argument( + "--prefix", + type=str, + default="wenetspeech4tts", + help="prefix of the manifest file", + ) + + parser.add_argument( + "--suffix", + type=str, + default="jsonl.gz", + help="suffix of the manifest file", + ) + + parser.add_argument( + "--split", + type=int, + default=100, + help="Split the cut_set into multiple parts", + ) + + parser.add_argument( + "--resample-to-24kHz", + default=True, + help="Resample the audio to 24kHz", + ) + + parser.add_argument( + "--extractor", + type=str, + choices=["bigvgan", "hifigan"], + default="bigvgan", + help="The type of extractor to use", + ) + return parser + + +def compute_fbank(args): + src_dir = Path(args.src_dir) + output_dir = Path(args.output_dir) + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + num_jobs = min(args.num_jobs, os.cpu_count()) + dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip().split(" ") + + logging.info(f"num_jobs: {num_jobs}") + logging.info(f"src_dir: {src_dir}") + logging.info(f"output_dir: {output_dir}") + logging.info(f"dataset_parts: {dataset_parts}") + if args.extractor == "bigvgan": + config = MatchaFbankConfig( + n_fft=1024, + n_mels=100, + sampling_rate=24_000, + hop_length=256, + win_length=1024, + f_min=0, + f_max=None, + ) + elif args.extractor == "hifigan": + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=22050, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + else: + raise NotImplementedError(f"Extractor {args.extractor} is not implemented") + + extractor = MatchaFbank(config) + + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=args.src_dir, + prefix=args.prefix, + suffix=args.suffix, + types=["recordings", "supervisions", "cuts"], + ) + + with get_executor() as ex: + for partition, m in manifests.items(): + logging.info( + f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}" + ) + try: + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + except Exception: + cut_set = m["cuts"] + + if args.split > 1: + cut_sets = cut_set.split(args.split) + else: + cut_sets = [cut_set] + + for idx, part in enumerate(cut_sets): + if args.split > 1: + storage_path = f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}_{idx}" + else: + storage_path = ( + f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}" + ) + + if args.resample_to_24kHz: + part = part.resample(24000) + + with torch.no_grad(): + part = part.compute_and_store_features( + extractor=extractor, + storage_path=storage_path, + num_jobs=num_jobs if ex is None else 64, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + + if args.split > 1: + cuts_filename = ( + f"{args.prefix}_cuts_{partition}.{idx}.{args.suffix}" + ) + else: + cuts_filename = f"{args.prefix}_cuts_{partition}.{args.suffix}" + + part.to_file(f"{args.output_dir}/{cuts_filename}") + logging.info(f"Saved {cuts_filename}") + + +if __name__ == "__main__": + # Torch's multithreaded behavior needs to be disabled or + # it wastes a lot of CPU and slow things down. 
+    # Do this outside of main() in case it needs to take effect
+    # even when we are not invoking the main (e.g. when spawning subprocesses).
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    args = get_parser().parse_args()
+    compute_fbank(args)
diff --git a/egs/wenetspeech4tts/TTS/local/fbank.py b/egs/wenetspeech4tts/TTS/local/fbank.py
new file mode 120000
index 000000000..3cfb7fe3f
--- /dev/null
+++ b/egs/wenetspeech4tts/TTS/local/fbank.py
@@ -0,0 +1 @@
+../../../ljspeech/TTS/matcha/fbank.py
\ No newline at end of file
diff --git a/egs/wenetspeech4tts/TTS/prepare.sh b/egs/wenetspeech4tts/TTS/prepare.sh
index 54e140dbb..7b800d87e 100755
--- a/egs/wenetspeech4tts/TTS/prepare.sh
+++ b/egs/wenetspeech4tts/TTS/prepare.sh
@@ -1,12 +1,12 @@
 #!/usr/bin/env bash
 
 set -eou pipefail
 
 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 
 stage=1
-stop_stage=4
+stop_stage=7
 
 dl_dir=$PWD/download
 
@@ -98,3 +98,57 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   fi
   python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
 fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: build monotonic_align lib (used by matcha recipes)"
+  for recipe in matcha; do
+    if [ ! -d $recipe/monotonic_align/build ]; then
+      cd $recipe/monotonic_align
+      python3 setup.py build_ext --inplace
+      cd ../../
+    else
+      log "monotonic_align lib for $recipe already built"
+    fi
+  done
+fi
+
+subset="Basic"
+prefix="wenetspeech4tts"
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Generate fbank (used by ./matcha)"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.${prefix}.done ]; then
+    ./local/compute_mel_feat.py --dataset-parts $subset --split 100
+    touch data/fbank/.${prefix}.done
+  fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./matcha)"
+  if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
+    echo "Combining ${prefix} cuts"
+    pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")
+    lhotse combine $pieces data/fbank/${prefix}_cuts_${subset}.jsonl.gz
+  fi
+  if [ ! -e data/fbank/.${prefix}_split.done ]; then
+    echo "Splitting ${prefix} cuts into train, valid and test sets"
+
+    lhotse subset --last 800 \
+      data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
+      data/fbank/${prefix}_cuts_validtest.jsonl.gz
+    lhotse subset --first 400 \
+      data/fbank/${prefix}_cuts_validtest.jsonl.gz \
+      data/fbank/${prefix}_cuts_valid.jsonl.gz
+    lhotse subset --last 400 \
+      data/fbank/${prefix}_cuts_validtest.jsonl.gz \
+      data/fbank/${prefix}_cuts_test.jsonl.gz
+
+    rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
+
+    n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
+    lhotse subset --first $n \
+      data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
+      data/fbank/${prefix}_cuts_train.jsonl.gz
+    touch data/fbank/.${prefix}_split.done
+  fi
+fi
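
The stage-7 split above can also be expressed with lhotse's Python API. The
snippet below is an illustrative sketch only, not part of the recipe (it
assumes stage 6 has already written data/fbank/wenetspeech4tts_cuts_Basic.jsonl.gz);
the recipe itself relies on the lhotse CLI as shown in prepare.sh.

    # mirror_stage7_split.py -- hypothetical helper, for illustration only
    from lhotse import load_manifest

    cuts = load_manifest("data/fbank/wenetspeech4tts_cuts_Basic.jsonl.gz")
    n_train = len(cuts) - 800  # hold out the last 800 cuts for valid/test

    # First (total - 800) cuts become the training set.
    cuts.subset(first=n_train).to_file(
        "data/fbank/wenetspeech4tts_cuts_train.jsonl.gz"
    )
    # The held-out 800 cuts are split 400/400 into valid and test.
    holdout = cuts.subset(last=800)
    holdout.subset(first=400).to_file(
        "data/fbank/wenetspeech4tts_cuts_valid.jsonl.gz"
    )
    holdout.subset(last=400).to_file(
        "data/fbank/wenetspeech4tts_cuts_test.jsonl.gz"
    )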