Support different TTS model types.

low (quality) -> smaller model, runs faster; high (quality) -> larger model, runs slower.
Fangjun Kuang 2024-03-11 12:42:43 +08:00
parent ae61bd4090
commit b33d3820db
9 changed files with 145 additions and 36 deletions
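
As a usage sketch (hedged; paths and the non-new flags are assumptions based on the LJSpeech VITS recipe layout, only `--model-type` is introduced by this commit):

```
# Sketch only: the script path and surrounding flags are assumed, not taken from this diff.
cd egs/ljspeech/TTS        # assumed recipe directory
./vits/train.py \
  --exp-dir vits/exp \
  --model-type low         # one of: low, medium, high; an empty value keeps the original (high-sized) config
```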

View File

@@ -1,10 +1,10 @@
# Introduction

This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books.
A transcription is provided for each clip.
Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours.

The texts were published between 1884 and 1964, and are in the public domain.
The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain.

The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/).
@@ -35,4 +35,4 @@ To inference, use:
  --exp-dir vits/exp \
  --epoch 1000 \
  --tokens data/tokens.txt
```

View File

@@ -54,7 +54,7 @@ fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare LJSpeech manifest"
  # We assume that you have downloaded the LJSpeech corpus
-  # to $dl_dir/LJSpeech
+  # to $dl_dir/LJSpeech-1.1
  mkdir -p data/manifests
  if [ ! -e data/manifests/.ljspeech.done ]; then
    lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests

View File

@@ -25,9 +25,8 @@ Export the model to ONNX:
  --exp-dir vits/exp \
  --tokens data/tokens.txt

-It will generate two files inside vits/exp:
+It will generate one file inside vits/exp:
  - vits-epoch-1000.onnx
-  - vits-epoch-1000.int8.onnx (quantizated model)

See ./test_onnx.py for how to use the exported ONNX models.
"""
@@ -40,7 +39,6 @@ from typing import Dict, Tuple

import onnx
import torch
import torch.nn as nn
-from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params
@@ -75,6 +73,15 @@ def get_parser():
        help="""Path to vocabulary.""",
    )

+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="",
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )

    return parser
@@ -240,7 +247,7 @@ def main():
    model = OnnxModel(model=model)

    num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"generator parameters: {num_param}")
+    logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")

    suffix = f"epoch-{params.epoch}"
@@ -256,18 +263,6 @@
    )

    logging.info(f"Exported generator to {model_filename}")

-    # Generate int8 quantization models
-    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
-    logging.info("Generate int8 quantization models")
-    model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=model_filename,
-        model_output=model_filename_int8,
-        weight_type=QuantType.QUInt8,
-    )

if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
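
With the int8 path removed, the export step now produces only vits-epoch-1000.onnx. A hedged example of exporting a model trained with the new flag (the export-onnx.py script name and the exact flag spellings other than --model-type are assumptions based on the docstring above):

```
# Assumed invocation; vits/exp, epoch 1000, and data/tokens.txt mirror the docstring above.
./vits/export-onnx.py \
  --epoch 1000 \
  --exp-dir vits/exp \
  --tokens data/tokens.txt \
  --model-type low         # should match the value used during training
```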

View File

@@ -189,7 +189,7 @@ class VITSGenerator(torch.nn.Module):
        self.upsample_factor = int(np.prod(decoder_upsample_scales))
        self.spks = None
        if spks is not None and spks > 1:
-            assert global_channels > 0
+            assert global_channels > 0, global_channels
            self.spks = spks
            self.global_emb = torch.nn.Embedding(spks, global_channels)
        self.spk_embed_dim = None

View File

@@ -72,6 +72,15 @@ def get_parser():
        help="""Path to vocabulary.""",
    )

+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="",
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )

    return parser
@@ -94,6 +103,7 @@ def infer_dataset(
      tokenizer:
        Used to convert text to phonemes.
    """

+    # Background worker save audios to disk.
    def _save_worker(
        batch_size: int,
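
The inference script (infer.py is an assumed name for the file above) gains the same flag, so a dataset-level inference run would look roughly like the README's example plus --model-type; everything below except the new flag is an assumption:

```
# Hypothetical command; only --model-type is introduced by this commit.
./vits/infer.py \
  --epoch 1000 \
  --exp-dir vits/exp \
  --tokens data/tokens.txt \
  --model-type low         # should match the --model-type used for training
```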

View File

@@ -0,0 +1,51 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tokenizer import Tokenizer
from train import get_model, get_params
from vits import VITS


def test_model_type(model_type):
    tokens = "./data/tokens.txt"

    params = get_params()

    tokenizer = Tokenizer(tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_type = model_type

    model = get_model(params)
    generator = model.generator

    num_param = sum([p.numel() for p in generator.parameters()])
    print(
        f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M"
    )


def main():
    test_model_type("high")  # 35.63 M
    test_model_type("low")  # 7.55 M
    test_model_type("medium")  # 23.61 M
    test_model_type("")  # 35.63 M


if __name__ == "__main__":
    main()
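
To reproduce the parameter counts noted in main() above, the new test script can be run from the recipe directory once data/tokens.txt has been generated; the directory and file names below are assumptions:

```
# Assumed layout: the new test script lives next to train.py inside vits/.
cd egs/ljspeech/TTS
python3 ./vits/test_vits.py
# Per the comments in main(), this should report roughly 35.63 M (high and ""),
# 23.61 M (medium), and 7.55 M (low) generator parameters.
```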

View File

@@ -92,9 +92,9 @@ class TextEncoder(torch.nn.Module):
            x_lengths (Tensor): Length tensor (B,).

        Returns:
-            Tensor: Encoded hidden representation (B, attention_dim, T_text).
-            Tensor: Projected mean tensor (B, attention_dim, T_text).
-            Tensor: Projected scale tensor (B, attention_dim, T_text).
+            Tensor: Encoded hidden representation (B, embed_dim, T_text).
+            Tensor: Projected mean tensor (B, embed_dim, T_text).
+            Tensor: Projected scale tensor (B, embed_dim, T_text).
            Tensor: Mask tensor for input tensor (B, 1, T_text).
        """
@@ -108,6 +108,7 @@ class TextEncoder(torch.nn.Module):
        # encoder assume the channel last (B, T_text, embed_dim)
        x = self.encoder(x, key_padding_mask=pad_mask)

+        # Note: attention_dim == embed_dim
        # convert the channel first (B, embed_dim, T_text)
        x = x.transpose(1, 2)

View File

@@ -153,6 +153,15 @@ def get_parser():
        help="Whether to use half precision training.",
    )

+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="",
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )

    return parser
@@ -189,15 +198,6 @@ def get_params() -> AttributeDict:
        - feature_dim: The model input dim. It has to match the one used
          in computing features.

-        - subsampling_factor: The subsampling factor for the model.
-        - encoder_dim: Hidden dim for multi-head attention model.
-        - num_decoder_layers: Number of decoder layer of transformer decoder.
-        - warm_step: The warmup period that dictates the decay of the
-          scale on "simple" (un-pruned) loss.
    """
    params = AttributeDict(
        {
@@ -278,6 +278,7 @@ def get_model(params: AttributeDict) -> nn.Module:
        vocab_size=params.vocab_size,
        feature_dim=params.feature_dim,
        sampling_rate=params.sampling_rate,
+        model_type=params.model_type,
        mel_loss_params=mel_loss_params,
        lambda_adv=params.lambda_adv,
        lambda_mel=params.lambda_mel,
@@ -363,7 +364,7 @@ def train_one_epoch(
    model.train()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

-    # used to summary the stats over iterations in one epoch
+    # used to track the stats over iterations in one epoch
    tot_loss = MetricsTracker()

    saved_bad_model = False

View File

@@ -5,6 +5,7 @@
"""VITS module for GAN-TTS task."""

+import copy
from typing import Any, Dict, Optional, Tuple

import torch
@@ -38,6 +39,36 @@ AVAILABLE_DISCRIMINATORS = {
    "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator,  # NOQA
}

+LOW_CONFIG = {
+    "hidden_channels": 96,
+    "decoder_upsample_scales": (8, 8, 4),
+    "decoder_channels": 256,
+    "decoder_upsample_kernel_sizes": (16, 16, 8),
+    "decoder_resblock_kernel_sizes": (3, 5, 7),
+    "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
+    "text_encoder_cnn_module_kernel": 3,
+}

+MEDIUM_CONFIG = {
+    "hidden_channels": 192,
+    "decoder_upsample_scales": (8, 8, 4),
+    "decoder_channels": 256,
+    "decoder_upsample_kernel_sizes": (16, 16, 8),
+    "decoder_resblock_kernel_sizes": (3, 5, 7),
+    "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
+    "text_encoder_cnn_module_kernel": 3,
+}

+HIGH_CONFIG = {
+    "hidden_channels": 192,
+    "decoder_upsample_scales": (8, 8, 2, 2),
+    "decoder_channels": 512,
+    "decoder_upsample_kernel_sizes": (16, 16, 4, 4),
+    "decoder_resblock_kernel_sizes": (3, 7, 11),
+    "decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+    "text_encoder_cnn_module_kernel": 3,
+}

class VITS(nn.Module):
    """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
@@ -49,6 +80,7 @@ class VITS(nn.Module):
        feature_dim: int = 513,
        sampling_rate: int = 22050,
        generator_type: str = "vits_generator",
+        model_type: str = "",
        generator_params: Dict[str, Any] = {
            "hidden_channels": 192,
            "spks": None,
@@ -155,12 +187,13 @@ class VITS(nn.Module):
        """Initialize VITS module.

        Args:
-            idim (int): Input vocabrary size.
+            idim (int): Input vocabulary size.
            odim (int): Acoustic feature dimension. The actual output channels will
                be 1 since VITS is the end-to-end text-to-wave model but for the
                compatibility odim is used to indicate the acoustic feature dimension.
            sampling_rate (int): Sampling rate, not used for the training but it will
                be referred in saving waveform during the inference.
+            model_type (str): If not empty, must be one of: low, medium, high
            generator_type (str): Generator type.
            generator_params (Dict[str, Any]): Parameter dict for generator.
            discriminator_type (str): Discriminator type.
@@ -181,6 +214,24 @@ class VITS(nn.Module):
        """
        super().__init__()

+        generator_params = copy.deepcopy(generator_params)
+        discriminator_params = copy.deepcopy(discriminator_params)
+        generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params)
+        discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params)
+        feat_match_loss_params = copy.deepcopy(feat_match_loss_params)
+        mel_loss_params = copy.deepcopy(mel_loss_params)

+        if model_type != "":
+            assert model_type in ("low", "medium", "high"), model_type
+            if model_type == "low":
+                generator_params.update(LOW_CONFIG)
+            elif model_type == "medium":
+                generator_params.update(MEDIUM_CONFIG)
+            elif model_type == "high":
+                generator_params.update(HIGH_CONFIG)
+            else:
+                raise ValueError(f"Unknown model_type: ${model_type}")

        # define modules
        generator_class = AVAILABLE_GENERATERS[generator_type]
        if generator_type == "vits_generator":