diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md index 80be5a315..935bb1a88 100644 --- a/egs/ljspeech/TTS/README.md +++ b/egs/ljspeech/TTS/README.md @@ -1,10 +1,10 @@ # Introduction -This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books. -A transcription is provided for each clip. +This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books. +A transcription is provided for each clip. Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours. -The texts were published between 1884 and 1964, and are in the public domain. +The texts were published between 1884 and 1964, and are in the public domain. The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain. The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/). @@ -35,4 +35,4 @@ To inference, use: --exp-dir vits/exp \ --epoch 1000 \ --tokens data/tokens.txt -``` \ No newline at end of file +``` diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index cbf27bd42..bded423ac 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -54,7 +54,7 @@ fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Prepare LJSpeech manifest" # We assume that you have downloaded the LJSpeech corpus - # to $dl_dir/LJSpeech + # to $dl_dir/LJSpeech-1.1 mkdir -p data/manifests if [ ! -e data/manifests/.ljspeech.done ]; then lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 58b166368..6055861e2 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -25,9 +25,8 @@ Export the model to ONNX: --exp-dir vits/exp \ --tokens data/tokens.txt -It will generate two files inside vits/exp: +It will generate one file inside vits/exp: - vits-epoch-1000.onnx - - vits-epoch-1000.int8.onnx (quantizated model) See ./test_onnx.py for how to use the exported ONNX models. """ @@ -40,7 +39,6 @@ from typing import Dict, Tuple import onnx import torch import torch.nn as nn -from onnxruntime.quantization import QuantType, quantize_dynamic from tokenizer import Tokenizer from train import get_model, get_params @@ -75,6 +73,15 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--model-type", + type=str, + default="", + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + return parser @@ -240,7 +247,7 @@ def main(): model = OnnxModel(model=model) num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"generator parameters: {num_param}") + logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M") suffix = f"epoch-{params.epoch}" @@ -256,18 +263,6 @@ def main(): ) logging.info(f"Exported generator to {model_filename}") - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" - quantize_dynamic( - model_input=model_filename, - model_output=model_filename_int8, - weight_type=QuantType.QUInt8, - ) - if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py index 66c8cedb1..b9add9e82 100644 --- a/egs/ljspeech/TTS/vits/generator.py +++ b/egs/ljspeech/TTS/vits/generator.py @@ -189,7 +189,7 @@ class VITSGenerator(torch.nn.Module): self.upsample_factor = int(np.prod(decoder_upsample_scales)) self.spks = None if spks is not None and spks > 1: - assert global_channels > 0 + assert global_channels > 0, global_channels self.spks = spks self.global_emb = torch.nn.Embedding(spks, global_channels) self.spk_embed_dim = None diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 9e7c71c6d..40988adc4 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -72,6 +72,15 @@ def get_parser(): help="""Path to vocabulary.""", ) + parser.add_argument( + "--model-type", + type=str, + default="", + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + return parser @@ -94,6 +103,7 @@ def infer_dataset( tokenizer: Used to convert text to phonemes. """ + # Background worker save audios to disk. def _save_worker( batch_size: int, diff --git a/egs/ljspeech/TTS/vits/test_model.py b/egs/ljspeech/TTS/vits/test_model.py new file mode 100755 index 000000000..7b637f8f7 --- /dev/null +++ b/egs/ljspeech/TTS/vits/test_model.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tokenizer import Tokenizer +from train import get_model, get_params +from vits import VITS + + +def test_model_type(model_type): + tokens = "./data/tokens.txt" + + params = get_params() + + tokenizer = Tokenizer(tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_type = model_type + + model = get_model(params) + generator = model.generator + + num_param = sum([p.numel() for p in generator.parameters()]) + print( + f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M" + ) + + +def main(): + test_model_type("high") # 35.63 M + test_model_type("low") # 7.55 M + test_model_type("medium") # 23.61 M + test_model_type("") # 35.63 M + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py index fcbae7103..9b21ed9cb 100644 --- a/egs/ljspeech/TTS/vits/text_encoder.py +++ b/egs/ljspeech/TTS/vits/text_encoder.py @@ -92,9 +92,9 @@ class TextEncoder(torch.nn.Module): x_lengths (Tensor): Length tensor (B,). Returns: - Tensor: Encoded hidden representation (B, attention_dim, T_text). - Tensor: Projected mean tensor (B, attention_dim, T_text). - Tensor: Projected scale tensor (B, attention_dim, T_text). + Tensor: Encoded hidden representation (B, embed_dim, T_text). + Tensor: Projected mean tensor (B, embed_dim, T_text). + Tensor: Projected scale tensor (B, embed_dim, T_text). Tensor: Mask tensor for input tensor (B, 1, T_text). """ @@ -108,6 +108,7 @@ class TextEncoder(torch.nn.Module): # encoder assume the channel last (B, T_text, embed_dim) x = self.encoder(x, key_padding_mask=pad_mask) + # Note: attention_dim == embed_dim # convert the channel first (B, embed_dim, T_text) x = x.transpose(1, 2) diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index 6589b75ff..767689b6c 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -153,6 +153,15 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--model-type", + type=str, + default="", + help="""If not empty, valid values are: low, medium, high. + It controls the model size. low -> runs faster. + """, + ) + return parser @@ -189,15 +198,6 @@ def get_params() -> AttributeDict: - feature_dim: The model input dim. It has to match the one used in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. """ params = AttributeDict( { @@ -278,6 +278,7 @@ def get_model(params: AttributeDict) -> nn.Module: vocab_size=params.vocab_size, feature_dim=params.feature_dim, sampling_rate=params.sampling_rate, + model_type=params.model_type, mel_loss_params=mel_loss_params, lambda_adv=params.lambda_adv, lambda_mel=params.lambda_mel, @@ -363,7 +364,7 @@ def train_one_epoch( model.train() device = model.device if isinstance(model, DDP) else next(model.parameters()).device - # used to summary the stats over iterations in one epoch + # used to track the stats over iterations in one epoch tot_loss = MetricsTracker() saved_bad_model = False diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index b4f0c21e6..43d8ce6a3 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -5,6 +5,7 @@ """VITS module for GAN-TTS task.""" +import copy from typing import Any, Dict, Optional, Tuple import torch @@ -38,6 +39,36 @@ AVAILABLE_DISCRIMINATORS = { "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA } +LOW_CONFIG = { + "hidden_channels": 96, + "decoder_upsample_scales": (8, 8, 4), + "decoder_channels": 256, + "decoder_upsample_kernel_sizes": (16, 16, 8), + "decoder_resblock_kernel_sizes": (3, 5, 7), + "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)), + "text_encoder_cnn_module_kernel": 3, +} + +MEDIUM_CONFIG = { + "hidden_channels": 192, + "decoder_upsample_scales": (8, 8, 4), + "decoder_channels": 256, + "decoder_upsample_kernel_sizes": (16, 16, 8), + "decoder_resblock_kernel_sizes": (3, 5, 7), + "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)), + "text_encoder_cnn_module_kernel": 3, +} + +HIGH_CONFIG = { + "hidden_channels": 192, + "decoder_upsample_scales": (8, 8, 2, 2), + "decoder_channels": 512, + "decoder_upsample_kernel_sizes": (16, 16, 4, 4), + "decoder_resblock_kernel_sizes": (3, 7, 11), + "decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)), + "text_encoder_cnn_module_kernel": 3, +} + class VITS(nn.Module): """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`""" @@ -49,6 +80,7 @@ class VITS(nn.Module): feature_dim: int = 513, sampling_rate: int = 22050, generator_type: str = "vits_generator", + model_type: str = "", generator_params: Dict[str, Any] = { "hidden_channels": 192, "spks": None, @@ -155,12 +187,13 @@ class VITS(nn.Module): """Initialize VITS module. Args: - idim (int): Input vocabrary size. + idim (int): Input vocabulary size. odim (int): Acoustic feature dimension. The actual output channels will be 1 since VITS is the end-to-end text-to-wave model but for the compatibility odim is used to indicate the acoustic feature dimension. sampling_rate (int): Sampling rate, not used for the training but it will be referred in saving waveform during the inference. + model_type (str): If not empty, must be one of: low, medium, high generator_type (str): Generator type. generator_params (Dict[str, Any]): Parameter dict for generator. discriminator_type (str): Discriminator type. @@ -181,6 +214,24 @@ class VITS(nn.Module): """ super().__init__() + generator_params = copy.deepcopy(generator_params) + discriminator_params = copy.deepcopy(discriminator_params) + generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params) + discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params) + feat_match_loss_params = copy.deepcopy(feat_match_loss_params) + mel_loss_params = copy.deepcopy(mel_loss_params) + + if model_type != "": + assert model_type in ("low", "medium", "high"), model_type + if model_type == "low": + generator_params.update(LOW_CONFIG) + elif model_type == "medium": + generator_params.update(MEDIUM_CONFIG) + elif model_type == "high": + generator_params.update(HIGH_CONFIG) + else: + raise ValueError(f"Unknown model_type: ${model_type}") + # define modules generator_class = AVAILABLE_GENERATERS[generator_type] if generator_type == "vits_generator":