From 9681263c0d4cf86f83431f63e253f9ca2659beb7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 12 Mar 2024 16:39:14 +0800 Subject: [PATCH] Update README --- egs/ljspeech/TTS/README.md | 16 ++++++++++++++++ egs/ljspeech/TTS/vits/export-onnx.py | 12 ++++++++---- egs/ljspeech/TTS/vits/test_onnx.py | 27 +++++++++++++++++++++++---- egs/ljspeech/TTS/vits/tokenizer.py | 10 +++++++++- 4 files changed, 56 insertions(+), 9 deletions(-) diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md index 935bb1a88..9cc2b0f29 100644 --- a/egs/ljspeech/TTS/README.md +++ b/egs/ljspeech/TTS/README.md @@ -36,3 +36,19 @@ To inference, use: --epoch 1000 \ --tokens data/tokens.txt ``` + +## Quality vs speed + +If you feel that the trained model is slow at runtime, you can specify the +argument `--model-type` during training. Possible values are: + + - `low`, means **low** quality. The resulting model is very small in file size + and runs very fast. The following is a wave file generatd by a `low` model + + https://github.com/k2-fsa/icefall/assets/5284924/d5758c24-470d-40ee-b089-e57fcba81633 + + The text is "Ask not what your country can do for you; ask what you can do for your country." + + - `medium`, means **medium** quality. + - `high`, means **high** quality + diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 6055861e2..8d66d5b35 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -143,7 +143,7 @@ class OnnxModel(nn.Module): Return a tuple containing: - audio, generated wavform tensor, (B, T_wav) """ - audio, _, _ = self.model.inference( + audio, _, _ = self.model.generator.inference( text=tokens, text_lengths=tokens_lens, noise_scale=noise_scale, @@ -205,6 +205,11 @@ def export_model_onnx( }, ) + if model.model.spks is None: + num_speakers = 1 + else: + num_speakers = model.model.spks + meta_data = { "model_type": "vits", "version": "1", @@ -213,8 +218,8 @@ def export_model_onnx( "language": "English", "voice": "en-us", # Choose your language appropriately "has_espeak": 1, - "n_speakers": 1, - "sample_rate": 22050, # Must match the real sample rate + "n_speakers": num_speakers, + "sample_rate": model.model.sampling_rate, # Must match the real sample rate } logging.info(f"meta_data: {meta_data}") @@ -240,7 +245,6 @@ def main(): load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model = model.generator model.to("cpu") model.eval() diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index 4f46e8e6c..b3805fadb 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -54,6 +54,20 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--text", + type=str, + default="Ask not what your country can do for you; ask what you can do for your country.", + help="Text to generate speech for", + ) + + parser.add_argument( + "--output-filename", + type=str, + default="test_onnx.wav", + help="Filename to save the generated wave file.", + ) + return parser @@ -61,7 +75,7 @@ class OnnxModel: def __init__(self, model_filename: str): session_opts = ort.SessionOptions() session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 + session_opts.intra_op_num_threads = 1 self.session_opts = session_opts @@ -72,6 +86,9 @@ class OnnxModel: ) logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + metadata = self.model.get_modelmeta().custom_metadata_map + self.sample_rate = int(metadata["sample_rate"]) + def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: """ Args: @@ -101,13 +118,14 @@ class OnnxModel: def main(): args = get_parser().parse_args() + logging.info(vars(args)) tokenizer = Tokenizer(args.tokens) logging.info("About to create onnx model") model = OnnxModel(args.model_filename) - text = "I went there to see the land, the people and how their system works, end quote." + text = args.text tokens = tokenizer.texts_to_token_ids( [text], intersperse_blank=True, add_sos=True, add_eos=True ) @@ -115,8 +133,9 @@ def main(): tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) audio = model(tokens, tokens_lens) # (1, T') - torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) - logging.info("Saved to test_onnx.wav") + output_filename = args.output_filename + torchaudio.save(output_filename, audio, sample_rate=model.sample_rate) + logging.info(f"Saved to {output_filename}") if __name__ == "__main__": diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 9a5a9090e..8144ffe1e 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -18,7 +18,15 @@ import logging from typing import Dict, List import tacotron_cleaner.cleaners -from piper_phonemize import phonemize_espeak + +try: + from piper_phonemize import phonemize_espeak +except Exception as ex: + raise RuntimeError( + f"{ex}\nPlease follow instructions in " + "../prepare.sh to install piper-phonemize" + ) + from utils import intersperse