diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
index 3aeb6add7..5c25c3cf4 100755
--- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
+++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
@@ -106,6 +106,8 @@ def compute_fbank_ljspeech(num_jobs: int):
         use_fft_mag=True,
         low_freq=0,
         high_freq=8000,
+        remove_dc_offset=False,
+        preemph_coeff=0,
         # should be identical to n_feats in ../matcha/train.py
         num_filters=80,
     )
diff --git a/egs/ljspeech/TTS/local/compute_fbank_statistics.py b/egs/ljspeech/TTS/local/compute_fbank_statistics.py
new file mode 100755
index 000000000..d0232c983
--- /dev/null
+++ b/egs/ljspeech/TTS/local/compute_fbank_statistics.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python3
+# Copyright         2024  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script computes the mean and std of the fbank features.
+"""
+
+import argparse
+import json
+import logging
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, load_manifest_lazy
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "manifest",
+        type=Path,
+        help="Path to the manifest file",
+    )
+
+    parser.add_argument(
+        "cmvn",
+        type=Path,
+        help="Path to the cmvn.json",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    manifest = args.manifest
+    logging.info(
+        f"Computing fbank mean and std for {manifest} and saving to {args.cmvn}"
+    )
+
+    assert manifest.is_file(), f"{manifest} does not exist"
+    cut_set = load_manifest_lazy(manifest)
+    assert isinstance(cut_set, CutSet), type(cut_set)
+
+    feat_dim = cut_set[0].features.num_features
+    num_frames = 0
+    s = 0
+    sq = 0
+    for c in cut_set:
+        f = torch.from_numpy(c.load_features())
+        num_frames += f.shape[0]
+        s += f.sum()
+        sq += f.square().sum()
+
+    fbank_mean = s / (num_frames * feat_dim)
+    fbank_var = sq / (num_frames * feat_dim) - fbank_mean * fbank_mean
+    print("fbank var", fbank_var)
+    fbank_std = fbank_var.sqrt()
+    with open(args.cmvn, "w") as f:
+        json.dump({"fbank_mean": fbank_mean.item(), "fbank_std": fbank_std.item()}, f)
+        f.write("\n")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
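Note: the script above pools one global mean/std over all frames and all 80 mel bins and writes the pair to cmvn.json. A minimal sketch of how a consumer could read the two values back and apply the same global normalization (path and JSON keys follow the script and the new prepare.sh stage; the helper names are illustrative only, not part of the diff):

```python
import json

import torch


def load_fbank_stats(path: str = "data/fbank/cmvn.json"):
    # cmvn.json is written by local/compute_fbank_statistics.py and holds a
    # single global mean/std pooled over all frames and mel bins.
    with open(path) as f:
        stats = json.load(f)
    return stats["fbank_mean"], stats["fbank_std"]


def normalize_fbank(features: torch.Tensor, mean: float, std: float) -> torch.Tensor:
    # Global CMVN: subtract the pooled mean and divide by the pooled std.
    return (features - mean) / std
```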
diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py
index 29a0f53a8..45d73bf4f 100755
--- a/egs/ljspeech/TTS/matcha/inference.py
+++ b/egs/ljspeech/TTS/matcha/inference.py
@@ -10,6 +10,7 @@ import soundfile as sf
 import torch
 from matcha.hifigan.config import v1
 from matcha.hifigan.denoiser import Denoiser
+from tokenizer import Tokenizer
 from matcha.hifigan.models import Generator as HiFiGAN
 from matcha.text import sequence_to_text, text_to_sequence
 from matcha.utils.utils import intersperse
@@ -28,7 +29,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=140,
+        default=1320,
         help="""It specifies the checkpoint to use for decoding.
         Note: Epoch counts from 1.
         """,
@@ -37,13 +38,19 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=Path,
-        default="matcha/exp",
+        default="matcha/exp-fbank",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,
     )
 
+    parser.add_argument(
+        "--tokens",
+        type=Path,
+        default="data/tokens.txt",
+    )
+
     return parser
 
 
@@ -71,19 +78,17 @@ def save_to_folder(filename: str, output: dict, folder: str):
     sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")
 
 
-def process_text(text: str):
-    x = torch.tensor(
-        intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
-        dtype=torch.long,
-        device="cpu",
-    )[None]
+def process_text(text: str, tokenizer):
+    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
+    x = torch.tensor(x, dtype=torch.long)
     x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
-    x_phones = sequence_to_text(x.squeeze(0).tolist())
-    return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}
+    return {"x_orig": text, "x": x, "x_lengths": x_lengths}
 
 
-def synthesise(model, n_timesteps, text, length_scale, temperature, spks=None):
-    text_processed = process_text(text)
+def synthesise(
+    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
+):
+    text_processed = process_text(text, tokenizer)
     start_t = dt.datetime.now()
     output = model.synthesise(
         text_processed["x"],
@@ -108,6 +113,11 @@ def main():
     params.update(vars(args))
     logging.info(params)
 
+    tokenizer = Tokenizer(params.tokens)
+    params.blank_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+    params.model_args.n_vocab = params.vocab_size
+
     logging.info("About to create model")
     model = get_model(params)
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@@ -117,12 +127,13 @@
     denoiser = Denoiser(vocoder, mode="zeros")
 
     texts = [
-        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
-        "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
+        "How are you doing, my friend",
+        # "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
+        # "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
     ]
 
     # Number of ODE Solver steps
-    n_timesteps = 2
+    n_timesteps = 3
 
     # Changes to the speaking rate
     length_scale = 1.0
@@ -135,6 +146,7 @@
     for i, text in enumerate(tqdm(texts)):
         output = synthesise(
             model=model,
+            tokenizer=tokenizer,
             n_timesteps=n_timesteps,
             text=text,
             length_scale=length_scale,
@@ -154,7 +166,7 @@
         print(f"{'*' * 53}")
         print(f"Phonetised text - {i}")
         print(f"{'-' * 53}")
-        print(output["x_phones"])
+        print(output["x"])
         print(f"{'*' * 53}")
         print(f"RTF:\t\t{output['rtf']:.6f}")
         print(f"RTF Waveform:\t{rtf_w:.6f}")
@@ -162,7 +174,7 @@
         rtfs_w.append(rtf_w)
 
         # Save the generated waveform
-        save_to_folder(i, output, folder="./my-output")
+        save_to_folder(i, output, folder="./my-output-1320")
 
     print(f"Number of ODE steps: {n_timesteps}")
     print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
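For context, the new inference path no longer goes through text_to_sequence/intersperse; it asks the Tokenizer built from data/tokens.txt for token ids directly. A rough usage sketch (the constructor and texts_to_token_ids call mirror the diff above; the example sentence is arbitrary and process_text is the function defined in inference.py):

```python
from tokenizer import Tokenizer  # resolves to matcha/tokenizer.py in this recipe

tokenizer = Tokenizer("data/tokens.txt")
batch = process_text("How are you doing, my friend", tokenizer)
# batch["x"] is a (1, T) LongTensor of token ids and batch["x_lengths"] holds T,
# which is what model.synthesise() consumes inside synthesise() above.
```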
diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py
index 94e089d7e..edf7e1eef 100755
--- a/egs/ljspeech/TTS/matcha/train.py
+++ b/egs/ljspeech/TTS/matcha/train.py
@@ -13,6 +13,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from lhotse.utils import fix_random_seed
+from matcha.data.text_mel_datamodule import TextMelDataModule
 from matcha.models.matcha_tts import MatchaTTS
 from matcha.tokenizer import Tokenizer
 from matcha.utils.model import fix_len_compatibility
@@ -122,8 +123,11 @@ def get_parser():
 def get_data_statistics():
     return AttributeDict(
         {
-            "mel_mean": 0.0,
-            "mel_std": 1.0,
+            # "mel_mean": -5.517028331756592,  # matcha-tts
+            # "mel_std": 2.0643954277038574,
+            # ours
+            "mel_mean": -1.168782114982605,
+            "mel_std": 1.9283572435379028,
         }
     )
 
@@ -134,7 +138,8 @@ def _get_data_params() -> AttributeDict:
             "name": "ljspeech",
             "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
             "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
-            "num_workers": 3,
+            "batch_size": 64,
+            "num_workers": 1,
             "pin_memory": False,
             "cleaners": ["english_cleaners2"],
             "add_blank": True,
@@ -289,8 +294,17 @@ def load_checkpoint_if_available(
     return saved_params
 
 
-def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
     """Parse batch data"""
+    mel_mean = params.data_args.data_statistics.mel_mean
+    mel_std_inv = 1 / params.data_args.data_statistics.mel_std
+    for i in range(batch["features"].shape[0]):
+        n = batch["features_lens"][i]
+        batch["features"][i : i + 1, :n, :] = (
+            batch["features"][i : i + 1, :n, :] - mel_mean
+        ) * mel_std_inv
+        batch["features"][i : i + 1, n:, :] = 0
+
     audio = batch["audio"].to(device)
     features = batch["features"].to(device)
     audio_lens = batch["audio_lens"].to(device)
@@ -298,7 +312,7 @@
     tokens = batch["tokens"]
 
     tokens = tokenizer.tokens_to_token_ids(
-        tokens, intersperse_blank=True, add_sos=True, add_eos=True
+        tokens, intersperse_blank=True, add_sos=False, add_eos=False
     )
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
@@ -315,7 +329,7 @@
     # features_lens[features_lens.argmax()] += pad
 
-    return audio, audio_lens, features, features_lens, tokens, tokens_lens
+    return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long()
 
 
 def compute_validation_loss(
@@ -336,28 +350,36 @@
     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
+            if "tokens" in batch:
-            (
-                audio,
-                audio_lens,
-                features,
-                features_lens,
-                tokens,
-                tokens_lens,
-            ) = prepare_input(batch, tokenizer, device)
+                (
+                    audio,
+                    audio_lens,
+                    features,
+                    features_lens,
+                    tokens,
+                    tokens_lens,
+                ) = prepare_input(batch, tokenizer, device, params)
 
-            losses = get_losses(
-                {
-                    "x": tokens,
-                    "x_lengths": tokens_lens,
-                    "y": features.permute(0, 2, 1),
-                    "y_lengths": features_lens,
-                    "spks": None,  # should change it for multi-speakers
-                    "durations": None,
-                }
-            )
+                losses = get_losses(
+                    {
+                        "x": tokens,
+                        "x_lengths": tokens_lens,
+                        "y": features.permute(0, 2, 1),
+                        "y_lengths": features_lens,
+                        "spks": None,  # should change it for multi-speakers
+                        "durations": None,
+                    }
+                )
 
-            batch_size = len(batch["tokens"])
+                batch_size = len(batch["tokens"])
+            else:
+                batch_size = batch["x"].shape[0]
+                batch["x"] = batch["x"].to(device)
+                batch["x_lengths"] = batch["x_lengths"].to(device)
+                batch["y"] = batch["y"].to(device)
+                batch["y_lengths"] = batch["y_lengths"].to(device)
+                losses = get_losses(batch)
 
             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size
@@ -451,24 +473,38 @@
         # features_lens, (N,), int32
         # tokens: List[List[str]], len(tokens) == N
 
-        batch_size = len(batch["tokens"])
-
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
-            batch, tokenizer, device
-        )
+        if "tokens" in batch:
+            batch_size = len(batch["tokens"])
+            (
+                audio,
+                audio_lens,
+                features,
+                features_lens,
+                tokens,
+                tokens_lens,
+            ) = prepare_input(batch, tokenizer, device, params)
+        else:
+            batch_size = batch["x"].shape[0]
 
         try:
             with autocast(enabled=params.use_fp16):
-                losses = get_losses(
-                    {
-                        "x": tokens,
-                        "x_lengths": tokens_lens,
-                        "y": features.permute(0, 2, 1),
-                        "y_lengths": features_lens,
-                        "spks": None,  # should change it for multi-speakers
-                        "durations": None,
-                    }
-                )
+                if "tokens" in batch:
+                    losses = get_losses(
+                        {
+                            "x": tokens,
+                            "x_lengths": tokens_lens,
+                            "y": features.permute(0, 2, 1),
+                            "y_lengths": features_lens,
+                            "spks": None,  # should change it for multi-speakers
+                            "durations": None,
+                        }
+                    )
+                else:
+                    batch["x"] = batch["x"].to(device)
+                    batch["x_lengths"] = batch["x_lengths"].to(device)
+                    batch["y"] = batch["y"].to(device)
+                    batch["y_lengths"] = batch["y_lengths"].to(device)
+                    losses = get_losses(batch)
 
                 loss = sum(losses.values())
@@ -586,6 +622,7 @@ def run(rank, world_size, args):
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
     params.model_args.n_vocab = params.vocab_size
+    params.model_args.n_vocab = 178
 
     logging.info(params)
     print(params)
@@ -595,7 +632,6 @@ def run(rank, world_size, args):
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of parameters: {num_param}")
-    print(f"Number of parameters: {num_param}")
 
     assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(params=params, model=model)
@@ -609,13 +645,21 @@ def run(rank, world_size, args):
     optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)
 
     logging.info("About to create datamodule")
-    ljspeech = LJSpeechTtsDataModule(args)
 
-    train_cuts = ljspeech.train_cuts()
-    train_dl = ljspeech.train_dataloaders(train_cuts)
+    if False:
+        params.data_args.tokenizer = tokenizer
+        data_module = TextMelDataModule(hparams=params.data_args)
+        del params.data_args.tokenizer
+        train_dl = data_module.train_dataloader()
+        valid_dl = data_module.val_dataloader()
+    else:
+        ljspeech = LJSpeechTtsDataModule(args)
 
-    valid_cuts = ljspeech.valid_cuts()
-    valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+        train_cuts = ljspeech.train_cuts()
+        train_dl = ljspeech.train_dataloaders(train_cuts)
+
+        valid_cuts = ljspeech.valid_cuts()
+        valid_dl = ljspeech.valid_dataloaders(valid_cuts)
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -625,7 +669,8 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs + 1):
         logging.info(f"Start epoch {epoch}")
         fix_random_seed(params.seed + epoch - 1)
-        train_dl.sampler.set_epoch(epoch - 1)
+        if "sampler" in train_dl:
+            train_dl.sampler.set_epoch(epoch - 1)
 
         params.cur_epoch = epoch
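The mel_mean/mel_std now hard-coded in get_data_statistics() appear to be the values produced for this fbank setup by local/compute_fbank_statistics.py (stage 9 of prepare.sh below). If one wanted to avoid hard-coding, a small sketch of reading them from cmvn.json instead (not part of this diff; the keys and path follow the statistics script, and AttributeDict is the same class train.py already uses):

```python
import json

from icefall.utils import AttributeDict


def get_data_statistics(cmvn: str = "data/fbank/cmvn.json") -> AttributeDict:
    # Same dict layout as the hard-coded version in train.py, but the values
    # come from the cmvn.json written during data preparation.
    with open(cmvn) as f:
        stats = json.load(f)
    return AttributeDict(
        {
            "mel_mean": stats["fbank_mean"],
            "mel_std": stats["fbank_std"],
        }
    )
```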
diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py
index c2be815d9..0fc16366e 100644
--- a/egs/ljspeech/TTS/matcha/tts_datamodule.py
+++ b/egs/ljspeech/TTS/matcha/tts_datamodule.py
@@ -181,6 +181,8 @@ class LJSpeechTtsDataModule:
                 frame_length=1024 / sampling_rate,  # (in second),
                 frame_shift=256 / sampling_rate,  # (in second)
                 use_fft_mag=True,
+                remove_dc_offset=False,
+                preemph_coeff=0,
                 low_freq=0,
                 high_freq=8000,
                 # should be identical to n_feats in ./train.py
@@ -242,6 +244,8 @@ class LJSpeechTtsDataModule:
                 frame_length=1024 / sampling_rate,  # (in second),
                 frame_shift=256 / sampling_rate,  # (in second)
                 use_fft_mag=True,
+                remove_dc_offset=False,
+                preemph_coeff=0,
                 low_freq=0,
                 high_freq=8000,
                 # should be identical to n_feats in ./train.py
@@ -286,6 +290,8 @@ class LJSpeechTtsDataModule:
                 frame_length=1024 / sampling_rate,  # (in second),
                 frame_shift=256 / sampling_rate,  # (in second)
                 use_fft_mag=True,
+                remove_dc_offset=False,
+                preemph_coeff=0,
                 low_freq=0,
                 high_freq=8000,
                 # should be identical to n_feats in ./train.py
diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh
index e1cd0897e..b140e6f01 100755
--- a/egs/ljspeech/TTS/prepare.sh
+++ b/egs/ljspeech/TTS/prepare.sh
@@ -191,3 +191,10 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
     touch data/fbank/.ljspeech_split.done
   fi
 fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compute fbank mean and std (used by ./matcha)"
+  if [ ! -f ./data/fbank/cmvn.json ]; then
+    ./local/compute_fbank_statistics.py ./data/fbank/ljspeech_cuts_train.jsonl.gz ./data/fbank/cmvn.json
+  fi
+fi
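After stage 9 has run, data/fbank/cmvn.json holds the two numbers matcha/train.py relies on. A quick sanity check one could run (file name and keys as above; the expected values are whatever the script computed for your feature configuration):

```python
import json

with open("data/fbank/cmvn.json") as f:
    stats = json.load(f)

# These should agree with the mel_mean / mel_std used in matcha/train.py,
# since both describe the same fbank configuration.
print(f"fbank_mean={stats['fbank_mean']:.6f} fbank_std={stats['fbank_std']:.6f}")
```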