diff --git a/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py
new file mode 120000
index 000000000..85255ba0c
--- /dev/null
+++ b/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py
@@ -0,0 +1 @@
+../local/compute_fbank_ljspeech.py
\ No newline at end of file
diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py
new file mode 100755
index 000000000..c56e2da89
--- /dev/null
+++ b/egs/ljspeech/TTS/matcha/export_onnx.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+
+import json
+import logging
+
+import torch
+from inference import get_parser
+from tokenizer import Tokenizer
+from train import get_model, get_params
+from icefall.checkpoint import load_checkpoint
+from onnxruntime.quantization import QuantType, quantize_dynamic
+
+
+class ModelWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lengths: torch.Tensor,
+        temperature: torch.Tensor,
+        length_scale: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+          x: (batch_size, num_tokens), torch.int64
+          x_lengths: (batch_size,), torch.int64
+          temperature: (1,), torch.float32
+          length_scale: (1,), torch.float32
+        Returns:
+          mel: (batch_size, feat_dim, num_frames)
+
+        """
+        mel = self.model.synthesise(
+            x=x,
+            x_lengths=x_lengths,
+            n_timesteps=3,
+            temperature=temperature,
+            length_scale=length_scale,
+        )["mel"]
+
+        # mel: (batch_size, feat_dim, num_frames)
+
+        return mel
+
+
+@torch.inference_mode()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    params = get_params()
+
+    params.update(vars(args))
+
+    tokenizer = Tokenizer(params.tokens)
+    params.blank_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+    params.model_args.n_vocab = params.vocab_size
+
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+    wrapper = ModelWrapper(model)
+    wrapper.eval()
+
+    # Use a large value so that the rotary position embedding in the text
+    # encoder has a large initial length
+    x = torch.ones(1, 2000, dtype=torch.int64)
+    x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
+    temperature = torch.tensor([1.0])
+    length_scale = torch.tensor([1.0])
+    mel = wrapper(x, x_lengths, temperature, length_scale)
+    print("mel", mel.shape)
+
+    opset_version = 14
+    filename = "model.onnx"
+    torch.onnx.export(
+        wrapper,
+        (x, x_lengths, temperature, length_scale),
+        filename,
+        opset_version=opset_version,
+        input_names=["x", "x_length", "temperature", "length_scale"],
+        output_names=["mel"],
+        dynamic_axes={
+            "x": {0: "N", 1: "L"},
+            "x_length": {0: "N"},
+            "mel": {0: "N", 2: "L"},
+        },
+    )
+
+    print("Generate int8 quantization models")
+
+    filename_int8 = "model.int8.onnx"
+    quantize_dynamic(
+        model_input=filename,
+        model_output=filename_int8,
+        weight_type=QuantType.QInt8,
+    )
+
+    print(f"Saved to {filename} and {filename_int8}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
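Not part of the patch: a quick sanity check for the exported graph, assuming model.onnx was produced by the export_onnx.py script above and the onnx package is installed:

    import onnx

    m = onnx.load("model.onnx")
    onnx.checker.check_model(m)  # raises if the graph is structurally invalid
    print([i.name for i in m.graph.input])   # expect: x, x_length, temperature, length_scale
    print([o.name for o in m.graph.output])  # expect: mel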
diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py
index 45d73bf4f..49c9c708a 100755
--- a/egs/ljspeech/TTS/matcha/inference.py
+++ b/egs/ljspeech/TTS/matcha/inference.py
@@ -5,6 +5,7 @@ import datetime as dt
 import logging
 from pathlib import Path
 
+import json
 import numpy as np
 import soundfile as sf
 import torch
@@ -29,7 +30,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=1320,
+        default=2810,
         help="""It specifies the checkpoint to use for decoding.
         Note: Epoch counts from 1.
         """,
@@ -38,7 +39,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=Path,
-        default="matcha/exp-fbank",
+        default="matcha/exp-new-3",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -51,6 +52,13 @@
         default="data/tokens.txt",
     )
 
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to cmvn.json.""",
+    )
+
     return parser
@@ -111,13 +119,21 @@ def main():
 
     params = get_params()
     params.update(vars(args))
-    logging.info(params)
 
     tokenizer = Tokenizer(params.tokens)
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
     params.model_args.n_vocab = params.vocab_size
 
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+    logging.info(params)
+
     logging.info("About to create model")
     model = get_model(params)
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@@ -127,9 +143,9 @@ def main():
     denoiser = Denoiser(vocoder, mode="zeros")
 
     texts = [
-        "How are you doing, my friend",
-        # "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
-        # "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
+        "How are you doing? my friend.",
+        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
+        "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
     ]
 
     # Number of ODE Solver steps
@@ -174,7 +190,7 @@
         rtfs_w.append(rtf_w)
 
         # Save the generated waveform
-        save_to_folder(i, output, folder="./my-output-1320")
+        save_to_folder(i, output, folder=f"./my-output-{params.epoch}")
 
     print(f"Number of ODE steps: {n_timesteps}")
     print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
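For reference, the file read through the new --cmvn option only needs the two keys the scripts access; a minimal sketch of its layout (the numeric values here are placeholders; the real ones are dataset statistics, close in magnitude to the constants removed from train.py below):

    import json

    with open("data/fbank/cmvn.json") as f:
        stats = json.load(f)
    # Expected layout, e.g. {"fbank_mean": -1.17, "fbank_std": 1.93}
    print(stats["fbank_mean"], stats["fbank_std"])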
diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py
new file mode 100755
index 000000000..1a973bcff
--- /dev/null
+++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python3
+import logging
+
+import onnxruntime as ort
+import torch
+from tokenizer import Tokenizer
+
+from inference import load_vocoder
+import soundfile as sf
+
+
+class OnnxModel:
+    def __init__(
+        self,
+        filename: str,
+    ):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.session_opts = session_opts
+        self.tokenizer = Tokenizer("./data/tokens.txt")
+        self.model = ort.InferenceSession(
+            filename,
+            sess_options=self.session_opts,
+            providers=["CPUExecutionProvider"],
+        )
+
+        for i in self.model.get_inputs():
+            print(i)
+
+        print("-----")
+
+        for i in self.model.get_outputs():
+            print(i)
+
+    def __call__(self, x: torch.Tensor):
+        assert x.ndim == 2, x.shape
+        assert x.shape[0] == 1, x.shape
+
+        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
+        print("x_lengths", x_lengths)
+        print("x", x.shape)
+
+        temperature = torch.tensor([1.0], dtype=torch.float32)
+        length_scale = torch.tensor([1.0], dtype=torch.float32)
+
+        mel = self.model.run(
+            [self.model.get_outputs()[0].name],
+            {
+                self.model.get_inputs()[0].name: x.numpy(),
+                self.model.get_inputs()[1].name: x_lengths.numpy(),
+                self.model.get_inputs()[2].name: temperature.numpy(),
+                self.model.get_inputs()[3].name: length_scale.numpy(),
+            },
+        )[0]
+
+        return torch.from_numpy(mel)
+
+
+@torch.inference_mode()
+def main():
+    model = OnnxModel("./model.onnx")
+    # text = "hello, how are you doing?"
+    text = "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar."
+    x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
+    x = torch.tensor(x, dtype=torch.int64)
+    mel = model(x)
+    print("mel", mel.shape)  # (1, 80, 170)
+
+    vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
+    audio = vocoder(mel).clamp(-1, 1)
+    print("audio", audio.shape)  # (1, 1, num_samples)
+    audio = audio.squeeze()
+
+    # skip denoiser
+    sf.write("onnx.wav", audio, 22050, "PCM_16")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
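The int8 model written by export_onnx.py can be dropped into the same script; a one-line variation (speed and quality of the quantized model are not measured here):

    model = OnnxModel("./model.int8.onnx")  # instead of "./model.onnx"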
diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py
index edf7e1eef..bb9307864 100755
--- a/egs/ljspeech/TTS/matcha/train.py
+++ b/egs/ljspeech/TTS/matcha/train.py
@@ -7,6 +7,7 @@ import logging
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Union
+import json
 
 import k2
 import torch
@@ -90,6 +91,13 @@
         help="""Path to vocabulary.""",
     )
 
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to cmvn.json.""",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -123,11 +131,8 @@ def get_parser():
 def get_data_statistics():
     return AttributeDict(
         {
-            # "mel_mean": -5.517028331756592,  # matcha-tts
-            # "mel_std": 2.0643954277038574,
-            # ours
-            "mel_mean": -1.168782114982605,
-            "mel_std": 1.9283572435379028,
+            "mel_mean": 0,
+            "mel_std": 1,
         }
     )
@@ -138,9 +143,9 @@ def _get_data_params() -> AttributeDict:
             "name": "ljspeech",
             "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
             "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
-            "batch_size": 64,
-            "num_workers": 1,
-            "pin_memory": False,
+            # "batch_size": 64,
+            # "num_workers": 1,
+            # "pin_memory": False,
             "cleaners": ["english_cleaners2"],
             "add_blank": True,
             "n_spks": 1,
@@ -312,7 +317,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, param
     tokens = batch["tokens"]
 
     tokens = tokenizer.tokens_to_token_ids(
-        tokens, intersperse_blank=True, add_sos=False, add_eos=False
+        tokens, intersperse_blank=True, add_sos=True, add_eos=True
    )
     tokens = k2.RaggedTensor(tokens)
     row_splits = tokens.shape.row_splits(1)
@@ -619,10 +624,17 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
 
     tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.pad_id
+    params.pad_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
     params.model_args.n_vocab = params.vocab_size
-    params.model_args.n_vocab = 178
+
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
 
     logging.info(params)
     print(params)
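Not part of the patch: the cmvn-loading block now appears verbatim in train.py, inference.py, and export_onnx.py; a small shared helper is one way to avoid the repetition (a sketch only; the function name and placement are hypothetical):

    import json

    def load_data_statistics(params) -> None:
        # Copy the global fbank mean/std from the --cmvn JSON file into
        # both data_args and model_args.
        with open(params.cmvn) as f:
            stats = json.load(f)
        for args in (params.data_args, params.model_args):
            args.data_statistics.mel_mean = stats["fbank_mean"]
            args.data_statistics.mel_std = stats["fbank_std"]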
diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py
index 0fc16366e..0227d9fdb 100644
--- a/egs/ljspeech/TTS/matcha/tts_datamodule.py
+++ b/egs/ljspeech/TTS/matcha/tts_datamodule.py
@@ -24,7 +24,8 @@ from pathlib import Path
 from typing import Any, Dict, Optional
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
+from lhotse import CutSet, load_manifest_lazy
+from compute_fbank_ljspeech import MyFbank, MyFbankConfig
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
@@ -176,22 +177,19 @@ class LJSpeechTtsDataModule:
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = FbankConfig(
+            config = MyFbankConfig(
+                n_fft=1024,
+                n_mels=80,
                 sampling_rate=sampling_rate,
-                frame_length=1024 / sampling_rate,  # (in second),
-                frame_shift=256 / sampling_rate,  # (in second)
-                use_fft_mag=True,
-                remove_dc_offset=False,
-                preemph_coeff=0,
-                low_freq=0,
-                high_freq=8000,
-                # should be identical to n_feats in ./train.py
-                num_filters=80,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
             )
             train = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
@@ -229,7 +227,8 @@
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=False,
+            persistent_workers=True,
+            pin_memory=True,
             worker_init_fn=worker_init_fn,
         )
@@ -239,22 +238,19 @@
         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = FbankConfig(
+            config = MyFbankConfig(
+                n_fft=1024,
+                n_mels=80,
                 sampling_rate=sampling_rate,
-                frame_length=1024 / sampling_rate,  # (in second),
-                frame_shift=256 / sampling_rate,  # (in second)
-                use_fft_mag=True,
-                remove_dc_offset=False,
-                preemph_coeff=0,
-                low_freq=0,
-                high_freq=8000,
-                # should be identical to n_feats in ./train.py
-                num_filters=80,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
             )
             validate = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -276,7 +272,8 @@
             sampler=valid_sampler,
             batch_size=None,
             num_workers=2,
-            persistent_workers=False,
+            persistent_workers=True,
+            pin_memory=True,
         )
 
         return valid_dl
@@ -285,22 +282,19 @@
         logging.info("About to create test dataset")
         if self.args.on_the_fly_feats:
             sampling_rate = 22050
-            config = FbankConfig(
+            config = MyFbankConfig(
+                n_fft=1024,
+                n_mels=80,
                 sampling_rate=sampling_rate,
-                frame_length=1024 / sampling_rate,  # (in second),
-                frame_shift=256 / sampling_rate,  # (in second)
-                use_fft_mag=True,
-                remove_dc_offset=False,
-                preemph_coeff=0,
-                low_freq=0,
-                high_freq=8000,
-                # should be identical to n_feats in ./train.py
-                num_filters=80,
+                hop_length=256,
+                win_length=1024,
+                f_min=0,
+                f_max=8000,
             )
             test = SpeechSynthesisDataset(
                 return_text=False,
                 return_tokens=True,
-                feature_input_strategy=OnTheFlyFeatures(Fbank(config)),
+                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                 return_cuts=self.args.return_cuts,
             )
         else:
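Likewise not part of the patch: the identical MyFbankConfig is constructed three times above; a helper could factor it out and also document how the new fields map onto the removed lhotse FbankConfig settings (a sketch; the helper name is hypothetical, all field values come from the diff):

    def ljspeech_fbank_config(sampling_rate: int = 22050) -> MyFbankConfig:
        # Old FbankConfig equivalents:
        #   frame_length = 1024 / sampling_rate  ->  win_length = 1024 (samples)
        #   frame_shift  = 256 / sampling_rate   ->  hop_length = 256 (samples)
        #   num_filters  = 80                    ->  n_mels = 80
        #   low_freq / high_freq = 0 / 8000      ->  f_min / f_max = 0 / 8000
        return MyFbankConfig(
            n_fft=1024,
            n_mels=80,
            sampling_rate=sampling_rate,
            hop_length=256,
            win_length=1024,
            f_min=0,
            f_max=8000,
        )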