diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
index 5c25c3cf4..fee66da48 100755
--- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
+++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
@@ -27,28 +27,100 @@ The generated fbank features are saved in data/fbank.
 import argparse
 import logging
 import os
+from dataclasses import dataclass
 from pathlib import Path
+from typing import Union
 
+import numpy as np
 import torch
-from lhotse import (
-    CutSet,
-    Fbank,
-    FbankConfig,
-    LilcomChunkyWriter,
-    load_manifest,
-    load_manifest_lazy,
-)
+from lhotse import CutSet, LilcomChunkyWriter, load_manifest
 from lhotse.audio import RecordingSet
+from lhotse.features.base import FeatureExtractor, register_extractor
 from lhotse.supervision import SupervisionSet
+from lhotse.utils import Seconds, compute_num_frames
+from matcha.utils.audio import mel_spectrogram
 
 from icefall.utils import get_executor
 
-# Torch's multithreaded behavior needs to be disabled or
-# it wastes a lot of CPU and slow things down.
-# Do this outside of main() in case it needs to take effect
-# even when we are not invoking the main (e.g. when spawning subprocesses).
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
+
+@dataclass
+class MyFbankConfig:
+    n_fft: int
+    n_mels: int
+    sampling_rate: int
+    hop_length: int
+    win_length: int
+    f_min: float
+    f_max: float
+    device: str = "cpu"  # read by the "device" property below
+
+
+@register_extractor
+class MyFbank(FeatureExtractor):
+
+    name = "MyFbank"
+    config_type = MyFbankConfig
+
+    def __init__(self, config):
+        super().__init__(config=config)
+
+    @property
+    def device(self) -> Union[str, torch.device]:
+        return self.config.device
+
+    def feature_dim(self, sampling_rate: int) -> int:
+        return self.config.n_mels
+
+    def extract(
+        self,
+        samples: np.ndarray,
+        sampling_rate: int,
+    ) -> torch.Tensor:
+        # Check for sampling rate compatibility.
+        expected_sr = self.config.sampling_rate
+        assert sampling_rate == expected_sr, (
+            f"Mismatched sampling rate: extractor expects {expected_sr}, "
+            f"got {sampling_rate}"
+        )
+        samples = torch.from_numpy(samples)
+        assert samples.ndim == 2, samples.shape
+        assert samples.shape[0] == 1, samples.shape
+
+        mel = (
+            mel_spectrogram(
+                samples,
+                self.config.n_fft,
+                self.config.n_mels,
+                self.config.sampling_rate,
+                self.config.hop_length,
+                self.config.win_length,
+                self.config.f_min,
+                self.config.f_max,
+                center=False,
+            )
+            .squeeze()
+            .t()
+        )
+
+        assert mel.ndim == 2, mel.shape
+        assert mel.shape[1] == self.config.n_mels, mel.shape
+
+        num_frames = compute_num_frames(
+            samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
+        )
+
+        if mel.shape[0] > num_frames:
+            mel = mel[:num_frames]
+        elif mel.shape[0] < num_frames:
+            mel = mel.unsqueeze(0)
+            mel = torch.nn.functional.pad(
+                mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
+            ).squeeze(0)
+
+        return mel.numpy()
+
+    @property
+    def frame_shift(self) -> Seconds:
+        return self.config.hop_length / self.config.sampling_rate
 
 
 def get_parser():
@@ -77,10 +149,15 @@ def compute_fbank_ljspeech(num_jobs: int):
     logging.info(f"num_jobs: {num_jobs}")
     logging.info(f"src_dir: {src_dir}")
     logging.info(f"output_dir: {output_dir}")
-
-    sampling_rate = 22050
-    frame_length = 1024 / sampling_rate  # (in second)
-    frame_shift = 256 / sampling_rate  # (in second)
+    config = MyFbankConfig(
+        n_fft=1024,
+        n_mels=80,
+        sampling_rate=22050,
+        hop_length=256,
+        win_length=1024,
+        f_min=0,
+        f_max=8000,
+    )
 
     prefix = "ljspeech"
     suffix = "jsonl.gz"
@@ -93,25 +170,7 @@ def compute_fbank_ljspeech(num_jobs: int):
             src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
         )
 
-        # Differences with matcha-tts
-        # 1. we use pre-emphasis
-        # 2. we remove dc offset
-        # 3. we use a different window
-        # 4. we use a different mel filter bank matrix
-        # 5. we don't normalize features
-        config = FbankConfig(
-            sampling_rate=sampling_rate,
-            frame_length=frame_length,
-            frame_shift=frame_shift,
-            use_fft_mag=True,
-            low_freq=0,
-            high_freq=8000,
-            remove_dc_offset=False,
-            preemph_coeff=0,
-            # should be identical to n_feats in ../matcha/train.py
-            num_filters=80,
-        )
-        extractor = Fbank(config)
+        extractor = MyFbank(config)
 
         with get_executor() as ex:  # Initialize the executor only once.
             cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
@@ -135,6 +194,12 @@ def compute_fbank_ljspeech(num_jobs: int):
 
 
 if __name__ == "__main__":
+    # Torch's multithreaded behavior needs to be disabled, or it wastes a
+    # lot of CPU and slows things down. Do this under the __main__ guard so
+    # that merely importing this module (e.g. from validate_manifest.py)
+    # does not change the global torch threading settings.
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
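Reviewer note (not part of the patch): a minimal usage sketch of the new extractor, assuming this script and the matcha package are importable. The silent one-second waveform is a made-up input; all other names come from the diff above.

    import numpy as np

    from compute_fbank_ljspeech import MyFbank, MyFbankConfig

    # Same settings that compute_fbank_ljspeech() uses for LJSpeech.
    config = MyFbankConfig(
        n_fft=1024,
        n_mels=80,
        sampling_rate=22050,
        hop_length=256,
        win_length=1024,
        f_min=0,
        f_max=8000,
    )
    extractor = MyFbank(config)

    # extract() asserts a (1, num_samples) array at the configured rate.
    samples = np.zeros((1, 22050), dtype=np.float32)
    mel = extractor.extract(samples, sampling_rate=22050)
    print(mel.shape)  # (num_frames, 80), with num_frames ~= 22050 / 256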
diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py
index 68159ae03..bbd1bfe9d 100755
--- a/egs/ljspeech/TTS/local/validate_manifest.py
+++ b/egs/ljspeech/TTS/local/validate_manifest.py
@@ -35,6 +35,7 @@ from pathlib import Path
 
 from lhotse import CutSet, load_manifest_lazy
 from lhotse.dataset.speech_synthesis import validate_for_tts
+# Imported for its side effect: it registers the MyFbank extractor with
+# lhotse so that manifests referencing it by name can be loaded.
+from compute_fbank_ljspeech import MyFbank
 
 
 def get_args():
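Reviewer note (not part of the patch): the import above looks unused, but it runs the @register_extractor decorator from the previous file, which is what lets lhotse resolve the "MyFbank" name when deserializing feature manifests. A quick sketch of the registration taking effect, assuming a reasonably recent lhotse:

    from lhotse.features.base import get_extractor_type

    from compute_fbank_ljspeech import MyFbank  # noqa: F401

    # The registry lookup succeeds only because the import above ran.
    print(get_extractor_type("MyFbank"))  # <class 'compute_fbank_ljspeech.MyFbank'>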
diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py
index 747292197..7f41ab101 100755
--- a/egs/ljspeech/TTS/matcha/train.py
+++ b/egs/ljspeech/TTS/matcha/train.py
@@ -3,18 +3,17 @@
 
 import argparse
+import json
 import logging
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Union
 
-import json
 import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.utils import fix_random_seed
-from matcha.data.text_mel_datamodule import TextMelDataModule
 from matcha.models.matcha_tts import MatchaTTS
 from matcha.tokenizer import Tokenizer
 from matcha.utils.model import fix_len_compatibility
@@ -355,36 +354,27 @@ def compute_validation_loss(
     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
-            if "tokens" in batch:
+            (
+                audio,
+                audio_lens,
+                features,
+                features_lens,
+                tokens,
+                tokens_lens,
+            ) = prepare_input(batch, tokenizer, device, params)
 
-                (
-                    audio,
-                    audio_lens,
-                    features,
-                    features_lens,
-                    tokens,
-                    tokens_lens,
-                ) = prepare_input(batch, tokenizer, device, params)
+            losses = get_losses(
+                {
+                    "x": tokens,
+                    "x_lengths": tokens_lens,
+                    "y": features.permute(0, 2, 1),
+                    "y_lengths": features_lens,
+                    "spks": None,  # should be changed for multi-speaker models
+                    "durations": None,
+                }
+            )
 
-                losses = get_losses(
-                    {
-                        "x": tokens,
-                        "x_lengths": tokens_lens,
-                        "y": features.permute(0, 2, 1),
-                        "y_lengths": features_lens,
-                        "spks": None,  # should change it for multi-speakers
-                        "durations": None,
-                    }
-                )
-
-                batch_size = len(batch["tokens"])
-            else:
-                batch_size = batch["x"].shape[0]
-                batch["x"] = batch["x"].to(device)
-                batch["x_lengths"] = batch["x_lengths"].to(device)
-                batch["y"] = batch["y"].to(device)
-                batch["y_lengths"] = batch["y_lengths"].to(device)
-                losses = get_losses(batch)
+            batch_size = len(batch["tokens"])
 
             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
@@ -478,38 +468,28 @@ def train_one_epoch(
         #   features_lens, (N,), int32
         # tokens: List[List[str]], len(tokens) == N
 
-        if "tokens" in batch:
-            batch_size = len(batch["tokens"])
+        batch_size = len(batch["tokens"])
 
-            (
-                audio,
-                audio_lens,
-                features,
-                features_lens,
-                tokens,
-                tokens_lens,
-            ) = prepare_input(batch, tokenizer, device, params)
-        else:
-            batch_size = batch["x"].shape[0]
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+        ) = prepare_input(batch, tokenizer, device, params)
 
         try:
             with autocast(enabled=params.use_fp16):
-                if "tokens" in batch:
-                    losses = get_losses(
-                        {
-                            "x": tokens,
-                            "x_lengths": tokens_lens,
-                            "y": features.permute(0, 2, 1),
-                            "y_lengths": features_lens,
-                            "spks": None,  # should change it for multi-speakers
-                            "durations": None,
-                        }
-                    )
-                else:
-                    batch["x"] = batch["x"].to(device)
-                    batch["x_lengths"] = batch["x_lengths"].to(device)
-                    batch["y"] = batch["y"].to(device)
-                    batch["y_lengths"] = batch["y_lengths"].to(device)
batch["y_lengths"].to(device) - losses = get_losses(batch) + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) loss = sum(losses.values()) @@ -535,8 +515,9 @@ def train_one_epoch( raise if params.batch_idx_train % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different + # If the grad scale was less than 1, try increasing it. + # The _growth_interval of the grad scaler is configurable, + # but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() @@ -560,7 +541,8 @@ def train_one_epoch( logging.info( f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"global_batch_idx: {params.batch_idx_train}, " + f"batch size: {batch_size}, " f"loss[{loss_info}], tot_loss[{tot_loss}], " + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") ) @@ -588,7 +570,8 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + "Maximum memory allocated so far is " + f"{torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -658,20 +641,13 @@ def run(rank, world_size, args): logging.info("About to create datamodule") - if False: - params.data_args.tokenizer = tokenizer - data_module = TextMelDataModule(hparams=params.data_args) - del params.data_args.tokenizer - train_dl = data_module.train_dataloader() - valid_dl = data_module.val_dataloader() - else: - ljspeech = LJSpeechTtsDataModule(args) + ljspeech = LJSpeechTtsDataModule(args) - train_cuts = ljspeech.train_cuts() - train_dl = ljspeech.train_dataloaders(train_cuts) + train_cuts = ljspeech.train_cuts() + train_dl = ljspeech.train_dataloaders(train_cuts) - valid_cuts = ljspeech.valid_cuts() - valid_dl = ljspeech.valid_dataloaders(valid_cuts) + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: