Support different tts model types. (#1541)

This commit is contained in:
Fangjun Kuang 2024-03-12 22:29:21 +08:00 committed by GitHub
parent 959906e9dc
commit 81f518ea7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 265 additions and 49 deletions

View File

@ -56,7 +56,8 @@ Training
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
--tokens data/tokens.txt
--tokens data/tokens.txt \
--model-type high \
--max-duration 500
.. note::
@ -64,6 +65,11 @@ Training
You can adjust the hyper-parameters to control the size of the VITS model and
the training configurations. For more details, please run ``./vits/train.py --help``.
.. warning::
If you want a model that runs faster on CPU, please use ``--model-type low``
or ``--model-type medium``.
.. note::
The training can take a long time (usually a couple of days).
@ -95,8 +101,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
Export models
-------------
Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
``vits-epoch-*.onnx``.
.. code-block:: bash
@ -120,4 +126,7 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following link:
- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
- ``--model-type=high``: `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
- ``--model-type=medium``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>`_
- ``--model-type=low``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>`_

View File

@ -36,3 +36,68 @@ To inference, use:
--epoch 1000 \
--tokens data/tokens.txt
```
## Quality vs speed
If you feel that the trained model is slow at runtime, you can specify the
argument `--model-type` during training. Possible values are:
- `low`, means **low** quality. The resulting model is very small in file size
and runs very fast. The following is a wave file generatd by a `low` quality model
https://github.com/k2-fsa/icefall/assets/5284924/d5758c24-470d-40ee-b089-e57fcba81633
The text is `Ask not what your country can do for you; ask what you can do for your country.`
The exported onnx model has a file size of ``26.8 MB`` (float32).
- `medium`, means **medium** quality.
The following is a wave file generatd by a `medium` quality model
https://github.com/k2-fsa/icefall/assets/5284924/b199d960-3665-4d0d-9ae9-a1bb69cbc8ac
The text is `Ask not what your country can do for you; ask what you can do for your country.`
The exported onnx model has a file size of ``70.9 MB`` (float32).
- `high`, means **high** quality. This is the default value.
The following is a wave file generatd by a `high` quality model
https://github.com/k2-fsa/icefall/assets/5284924/b39f3048-73a6-4267-bf95-df5abfdb28fc
The text is `Ask not what your country can do for you; ask what you can do for your country.`
The exported onnx model has a file size of ``113 MB`` (float32).
A pre-trained `low` model trained using 4xV100 32GB GPU with the following command can be found at
<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
./vits/train.py \
--world-size 4 \
--num-epochs 1601 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
--model-type low \
--max-duration 800
```
A pre-trained `medium` model trained using 4xV100 32GB GPU with the following command can be found at
<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>
```bash
export CUDA_VISIBLE_DEVICES=4,5,6,7
./vits/train.py \
--world-size 4 \
--num-epochs 1000 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp-medium \
--model-type medium \
--max-duration 500
# (Note it is killed after `epoch-820.pt`)
```

View File

@ -54,7 +54,7 @@ fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare LJSpeech manifest"
# We assume that you have downloaded the LJSpeech corpus
# to $dl_dir/LJSpeech
# to $dl_dir/LJSpeech-1.1
mkdir -p data/manifests
if [ ! -e data/manifests/.ljspeech.done ]; then
lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests

View File

@ -25,9 +25,8 @@ Export the model to ONNX:
--exp-dir vits/exp \
--tokens data/tokens.txt
It will generate two files inside vits/exp:
It will generate one file inside vits/exp:
- vits-epoch-1000.onnx
- vits-epoch-1000.int8.onnx (quantizated model)
See ./test_onnx.py for how to use the exported ONNX models.
"""
@ -40,7 +39,6 @@ from typing import Dict, Tuple
import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params
@ -75,6 +73,16 @@ def get_parser():
help="""Path to vocabulary.""",
)
parser.add_argument(
"--model-type",
type=str,
default="high",
choices=["low", "medium", "high"],
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)
return parser
@ -136,7 +144,7 @@ class OnnxModel(nn.Module):
Return a tuple containing:
- audio, generated wavform tensor, (B, T_wav)
"""
audio, _, _ = self.model.inference(
audio, _, _ = self.model.generator.inference(
text=tokens,
text_lengths=tokens_lens,
noise_scale=noise_scale,
@ -198,6 +206,11 @@ def export_model_onnx(
},
)
if model.model.spks is None:
num_speakers = 1
else:
num_speakers = model.model.spks
meta_data = {
"model_type": "vits",
"version": "1",
@ -206,8 +219,8 @@ def export_model_onnx(
"language": "English",
"voice": "en-us", # Choose your language appropriately
"has_espeak": 1,
"n_speakers": 1,
"sample_rate": 22050, # Must match the real sample rate
"n_speakers": num_speakers,
"sample_rate": model.model.sampling_rate, # Must match the real sample rate
}
logging.info(f"meta_data: {meta_data}")
@ -233,14 +246,13 @@ def main():
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model = model.generator
model.to("cpu")
model.eval()
model = OnnxModel(model=model)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"generator parameters: {num_param}")
logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")
suffix = f"epoch-{params.epoch}"
@ -256,18 +268,6 @@ def main():
)
logging.info(f"Exported generator to {model_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
quantize_dynamic(
model_input=model_filename,
model_output=model_filename_int8,
weight_type=QuantType.QUInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

View File

@ -189,7 +189,7 @@ class VITSGenerator(torch.nn.Module):
self.upsample_factor = int(np.prod(decoder_upsample_scales))
self.spks = None
if spks is not None and spks > 1:
assert global_channels > 0
assert global_channels > 0, global_channels
self.spks = spks
self.global_emb = torch.nn.Embedding(spks, global_channels)
self.spk_embed_dim = None

View File

@ -72,6 +72,16 @@ def get_parser():
help="""Path to vocabulary.""",
)
parser.add_argument(
"--model-type",
type=str,
default="high",
choices=["low", "medium", "high"],
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)
return parser
@ -94,6 +104,7 @@ def infer_dataset(
tokenizer:
Used to convert text to phonemes.
"""
# Background worker save audios to disk.
def _save_worker(
batch_size: int,

View File

@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tokenizer import Tokenizer
from train import get_model, get_params
from vits import VITS
def test_model_type(model_type):
tokens = "./data/tokens.txt"
params = get_params()
tokenizer = Tokenizer(tokens)
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_type = model_type
model = get_model(params)
generator = model.generator
num_param = sum([p.numel() for p in generator.parameters()])
print(
f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M"
)
def main():
test_model_type("high") # 35.63 M
test_model_type("low") # 7.55 M
test_model_type("medium") # 23.61 M
if __name__ == "__main__":
main()

View File

@ -54,6 +54,20 @@ def get_parser():
help="""Path to vocabulary.""",
)
parser.add_argument(
"--text",
type=str,
default="Ask not what your country can do for you; ask what you can do for your country.",
help="Text to generate speech for",
)
parser.add_argument(
"--output-filename",
type=str,
default="test_onnx.wav",
help="Filename to save the generated wave file.",
)
return parser
@ -61,7 +75,7 @@ class OnnxModel:
def __init__(self, model_filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
@ -72,6 +86,9 @@ class OnnxModel:
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
metadata = self.model.get_modelmeta().custom_metadata_map
self.sample_rate = int(metadata["sample_rate"])
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
"""
Args:
@ -101,13 +118,14 @@ class OnnxModel:
def main():
args = get_parser().parse_args()
logging.info(vars(args))
tokenizer = Tokenizer(args.tokens)
logging.info("About to create onnx model")
model = OnnxModel(args.model_filename)
text = "I went there to see the land, the people and how their system works, end quote."
text = args.text
tokens = tokenizer.texts_to_token_ids(
[text], intersperse_blank=True, add_sos=True, add_eos=True
)
@ -115,8 +133,9 @@ def main():
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
audio = model(tokens, tokens_lens) # (1, T')
torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
logging.info("Saved to test_onnx.wav")
output_filename = args.output_filename
torchaudio.save(output_filename, audio, sample_rate=model.sample_rate)
logging.info(f"Saved to {output_filename}")
if __name__ == "__main__":

View File

@ -92,9 +92,9 @@ class TextEncoder(torch.nn.Module):
x_lengths (Tensor): Length tensor (B,).
Returns:
Tensor: Encoded hidden representation (B, attention_dim, T_text).
Tensor: Projected mean tensor (B, attention_dim, T_text).
Tensor: Projected scale tensor (B, attention_dim, T_text).
Tensor: Encoded hidden representation (B, embed_dim, T_text).
Tensor: Projected mean tensor (B, embed_dim, T_text).
Tensor: Projected scale tensor (B, embed_dim, T_text).
Tensor: Mask tensor for input tensor (B, 1, T_text).
"""
@ -108,6 +108,7 @@ class TextEncoder(torch.nn.Module):
# encoder assume the channel last (B, T_text, embed_dim)
x = self.encoder(x, key_padding_mask=pad_mask)
# Note: attention_dim == embed_dim
# convert the channel first (B, embed_dim, T_text)
x = x.transpose(1, 2)

View File

@ -18,7 +18,15 @@ import logging
from typing import Dict, List
import tacotron_cleaner.cleaners
try:
from piper_phonemize import phonemize_espeak
except Exception as ex:
raise RuntimeError(
f"{ex}\nPlease follow instructions in "
"../prepare.sh to install piper-phonemize"
)
from utils import intersperse

View File

@ -153,6 +153,16 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--model-type",
type=str,
default="high",
choices=["low", "medium", "high"],
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)
return parser
@ -189,15 +199,6 @@ def get_params() -> AttributeDict:
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- encoder_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- warm_step: The warmup period that dictates the decay of the
scale on "simple" (un-pruned) loss.
"""
params = AttributeDict(
{
@ -278,6 +279,7 @@ def get_model(params: AttributeDict) -> nn.Module:
vocab_size=params.vocab_size,
feature_dim=params.feature_dim,
sampling_rate=params.sampling_rate,
model_type=params.model_type,
mel_loss_params=mel_loss_params,
lambda_adv=params.lambda_adv,
lambda_mel=params.lambda_mel,
@ -363,7 +365,7 @@ def train_one_epoch(
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to summary the stats over iterations in one epoch
# used to track the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False

View File

@ -5,6 +5,7 @@
"""VITS module for GAN-TTS task."""
import copy
from typing import Any, Dict, Optional, Tuple
import torch
@ -38,6 +39,36 @@ AVAILABLE_DISCRIMINATORS = {
"hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA
}
LOW_CONFIG = {
"hidden_channels": 96,
"decoder_upsample_scales": (8, 8, 4),
"decoder_channels": 256,
"decoder_upsample_kernel_sizes": (16, 16, 8),
"decoder_resblock_kernel_sizes": (3, 5, 7),
"decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
"text_encoder_cnn_module_kernel": 3,
}
MEDIUM_CONFIG = {
"hidden_channels": 192,
"decoder_upsample_scales": (8, 8, 4),
"decoder_channels": 256,
"decoder_upsample_kernel_sizes": (16, 16, 8),
"decoder_resblock_kernel_sizes": (3, 5, 7),
"decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
"text_encoder_cnn_module_kernel": 3,
}
HIGH_CONFIG = {
"hidden_channels": 192,
"decoder_upsample_scales": (8, 8, 2, 2),
"decoder_channels": 512,
"decoder_upsample_kernel_sizes": (16, 16, 4, 4),
"decoder_resblock_kernel_sizes": (3, 7, 11),
"decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
"text_encoder_cnn_module_kernel": 5,
}
class VITS(nn.Module):
"""Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
@ -49,6 +80,7 @@ class VITS(nn.Module):
feature_dim: int = 513,
sampling_rate: int = 22050,
generator_type: str = "vits_generator",
model_type: str = "",
generator_params: Dict[str, Any] = {
"hidden_channels": 192,
"spks": None,
@ -155,12 +187,13 @@ class VITS(nn.Module):
"""Initialize VITS module.
Args:
idim (int): Input vocabrary size.
idim (int): Input vocabulary size.
odim (int): Acoustic feature dimension. The actual output channels will
be 1 since VITS is the end-to-end text-to-wave model but for the
compatibility odim is used to indicate the acoustic feature dimension.
sampling_rate (int): Sampling rate, not used for the training but it will
be referred in saving waveform during the inference.
model_type (str): If not empty, must be one of: low, medium, high
generator_type (str): Generator type.
generator_params (Dict[str, Any]): Parameter dict for generator.
discriminator_type (str): Discriminator type.
@ -181,6 +214,24 @@ class VITS(nn.Module):
"""
super().__init__()
generator_params = copy.deepcopy(generator_params)
discriminator_params = copy.deepcopy(discriminator_params)
generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params)
discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params)
feat_match_loss_params = copy.deepcopy(feat_match_loss_params)
mel_loss_params = copy.deepcopy(mel_loss_params)
if model_type != "":
assert model_type in ("low", "medium", "high"), model_type
if model_type == "low":
generator_params.update(LOW_CONFIG)
elif model_type == "medium":
generator_params.update(MEDIUM_CONFIG)
elif model_type == "high":
generator_params.update(HIGH_CONFIG)
else:
raise ValueError(f"Unknown model_type: ${model_type}")
# define modules
generator_class = AVAILABLE_GENERATERS[generator_type]
if generator_type == "vits_generator":