Support different TTS model types.

A new --model-type flag (added to the training, ONNX export, and inference scripts) selects the model size: low (quality) -> smaller model, runs faster; high (quality) -> larger model, runs slower.
Fangjun Kuang 2024-03-11 12:42:43 +08:00
parent ae61bd4090
commit b33d3820db
9 changed files with 145 additions and 36 deletions


@ -1,10 +1,10 @@
# Introduction
This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books.
A transcription is provided for each clip.
Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours.
The texts were published between 1884 and 1964, and are in the public domain.
The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain.
The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/).
@ -35,4 +35,4 @@ To inference, use:
--exp-dir vits/exp \
--epoch 1000 \
--tokens data/tokens.txt
```


@ -54,7 +54,7 @@ fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare LJSpeech manifest"
# We assume that you have downloaded the LJSpeech corpus
-# to $dl_dir/LJSpeech
+# to $dl_dir/LJSpeech-1.1
mkdir -p data/manifests
if [ ! -e data/manifests/.ljspeech.done ]; then
lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
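If you prefer to drive this stage from Python rather than the shell script, lhotse exposes the same recipe programmatically. A minimal sketch, assuming lhotse is installed and that $dl_dir points at ./download (the paths below are illustrative):

```
# Rough Python equivalent of the `lhotse prepare ljspeech` command above.
from lhotse.recipes import prepare_ljspeech

manifests = prepare_ljspeech(
    corpus_dir="download/LJSpeech-1.1",  # i.e. $dl_dir/LJSpeech-1.1
    output_dir="data/manifests",
)
print(list(manifests.keys()))  # recording/supervision manifests, also written to data/manifests
```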


@ -25,9 +25,8 @@ Export the model to ONNX:
--exp-dir vits/exp \
--tokens data/tokens.txt
-It will generate two files inside vits/exp:
+It will generate one file inside vits/exp:
- vits-epoch-1000.onnx
-- vits-epoch-1000.int8.onnx (quantized model)
See ./test_onnx.py for how to use the exported ONNX models.
"""
@ -40,7 +39,6 @@ from typing import Dict, Tuple
import onnx
import torch
import torch.nn as nn
-from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params
@ -75,6 +73,15 @@ def get_parser():
help="""Path to vocabulary.""",
)
parser.add_argument(
"--model-type",
type=str,
default="",
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)
return parser
@ -240,7 +247,7 @@ def main():
model = OnnxModel(model=model)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"generator parameters: {num_param}")
logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")
suffix = f"epoch-{params.epoch}"
@ -256,18 +263,6 @@ def main():
)
logging.info(f"Exported generator to {model_filename}")
-# Generate int8 quantization models
-# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
-logging.info("Generate int8 quantization models")
-model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
-quantize_dynamic(
-model_input=model_filename,
-model_output=model_filename_int8,
-weight_type=QuantType.QUInt8,
-)
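With this change the export script produces only the fp32 ONNX model. If you still want an int8 copy, you can quantize the exported file yourself; a minimal sketch mirroring the removed block above, with file names following the epoch-1000 example from the README:

```
# Optional post-export step; no longer done automatically by the export script.
from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="vits/exp/vits-epoch-1000.onnx",
    model_output="vits/exp/vits-epoch-1000.int8.onnx",
    weight_type=QuantType.QUInt8,
)
```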
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"


@ -189,7 +189,7 @@ class VITSGenerator(torch.nn.Module):
self.upsample_factor = int(np.prod(decoder_upsample_scales))
self.spks = None
if spks is not None and spks > 1:
-assert global_channels > 0
+assert global_channels > 0, global_channels
self.spks = spks
self.global_emb = torch.nn.Embedding(spks, global_channels)
self.spk_embed_dim = None


@ -72,6 +72,15 @@ def get_parser():
help="""Path to vocabulary.""",
)
parser.add_argument(
"--model-type",
type=str,
default="",
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)
return parser
@ -94,6 +103,7 @@ def infer_dataset(
tokenizer:
Used to convert text to phonemes.
"""
# Background worker to save audio files to disk.
def _save_worker(
batch_size: int,


@ -0,0 +1,51 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tokenizer import Tokenizer
from train import get_model, get_params
from vits import VITS
def test_model_type(model_type):
tokens = "./data/tokens.txt"
params = get_params()
tokenizer = Tokenizer(tokens)
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_type = model_type
model = get_model(params)
generator = model.generator
num_param = sum([p.numel() for p in generator.parameters()])
print(
f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M"
)
def main():
test_model_type("high") # 35.63 M
test_model_type("low") # 7.55 M
test_model_type("medium") # 23.61 M
test_model_type("") # 35.63 M
if __name__ == "__main__":
main()


@ -92,9 +92,9 @@ class TextEncoder(torch.nn.Module):
x_lengths (Tensor): Length tensor (B,).
Returns:
-Tensor: Encoded hidden representation (B, attention_dim, T_text).
-Tensor: Projected mean tensor (B, attention_dim, T_text).
-Tensor: Projected scale tensor (B, attention_dim, T_text).
+Tensor: Encoded hidden representation (B, embed_dim, T_text).
+Tensor: Projected mean tensor (B, embed_dim, T_text).
+Tensor: Projected scale tensor (B, embed_dim, T_text).
Tensor: Mask tensor for input tensor (B, 1, T_text).
"""
@ -108,6 +108,7 @@ class TextEncoder(torch.nn.Module):
# encoder assumes the channel last (B, T_text, embed_dim)
x = self.encoder(x, key_padding_mask=pad_mask)
# Note: attention_dim == embed_dim
# convert the channel first (B, embed_dim, T_text)
x = x.transpose(1, 2)
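For readers tracing shapes here: the encoder runs channel-last and the rest of the generator expects channel-first, which is all the transpose does. A tiny standalone illustration (random values, shapes only):

```
import torch

x = torch.randn(2, 17, 192)  # (B, T_text, embed_dim); embed_dim == attention_dim
x = x.transpose(1, 2)
print(x.shape)  # torch.Size([2, 192, 17]) == (B, embed_dim, T_text)
```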


@ -153,6 +153,15 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--model-type",
type=str,
default="",
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)
return parser
@ -189,15 +198,6 @@ def get_params() -> AttributeDict:
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- encoder_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- warm_step: The warmup period that dictates the decay of the
scale on "simple" (un-pruned) loss.
"""
params = AttributeDict(
{
@ -278,6 +278,7 @@ def get_model(params: AttributeDict) -> nn.Module:
vocab_size=params.vocab_size,
feature_dim=params.feature_dim,
sampling_rate=params.sampling_rate,
model_type=params.model_type,
mel_loss_params=mel_loss_params,
lambda_adv=params.lambda_adv,
lambda_mel=params.lambda_mel,
@ -363,7 +364,7 @@ def train_one_epoch(
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
-# used to summary the stats over iterations in one epoch
+# used to track the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False


@ -5,6 +5,7 @@
"""VITS module for GAN-TTS task."""
import copy
from typing import Any, Dict, Optional, Tuple
import torch
@ -38,6 +39,36 @@ AVAILABLE_DISCRIMINATORS = {
"hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA
}
LOW_CONFIG = {
"hidden_channels": 96,
"decoder_upsample_scales": (8, 8, 4),
"decoder_channels": 256,
"decoder_upsample_kernel_sizes": (16, 16, 8),
"decoder_resblock_kernel_sizes": (3, 5, 7),
"decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
"text_encoder_cnn_module_kernel": 3,
}
MEDIUM_CONFIG = {
"hidden_channels": 192,
"decoder_upsample_scales": (8, 8, 4),
"decoder_channels": 256,
"decoder_upsample_kernel_sizes": (16, 16, 8),
"decoder_resblock_kernel_sizes": (3, 5, 7),
"decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
"text_encoder_cnn_module_kernel": 3,
}
HIGH_CONFIG = {
"hidden_channels": 192,
"decoder_upsample_scales": (8, 8, 2, 2),
"decoder_channels": 512,
"decoder_upsample_kernel_sizes": (16, 16, 4, 4),
"decoder_resblock_kernel_sizes": (3, 7, 11),
"decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
"text_encoder_cnn_module_kernel": 3,
}
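These presets differ mainly in hidden_channels, decoder_channels, and the decoder resblock kernel sizes/dilations, but all three keep the same total decoder upsample factor, so the number of waveform samples generated per acoustic frame is unchanged. A quick standalone check (not part of the recipe), reusing the fact that the generator computes the factor as np.prod(decoder_upsample_scales):

```
import numpy as np

presets = {
    "low": (8, 8, 4),
    "medium": (8, 8, 4),
    "high": (8, 8, 2, 2),
}
for name, scales in presets.items():
    # each preset yields an upsample factor of 256
    print(name, int(np.prod(scales)))
```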
class VITS(nn.Module):
"""Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
@ -49,6 +80,7 @@ class VITS(nn.Module):
feature_dim: int = 513,
sampling_rate: int = 22050,
generator_type: str = "vits_generator",
model_type: str = "",
generator_params: Dict[str, Any] = {
"hidden_channels": 192,
"spks": None,
@ -155,12 +187,13 @@ class VITS(nn.Module):
"""Initialize VITS module.
Args:
-idim (int): Input vocabrary size.
+idim (int): Input vocabulary size.
odim (int): Acoustic feature dimension. The actual output channels will
be 1 since VITS is the end-to-end text-to-wave model but for the
compatibility odim is used to indicate the acoustic feature dimension.
sampling_rate (int): Sampling rate, not used for the training but it will
be referred in saving waveform during the inference.
model_type (str): If not empty, must be one of: low, medium, high
generator_type (str): Generator type.
generator_params (Dict[str, Any]): Parameter dict for generator.
discriminator_type (str): Discriminator type.
@ -181,6 +214,24 @@ class VITS(nn.Module):
"""
super().__init__()
generator_params = copy.deepcopy(generator_params)
discriminator_params = copy.deepcopy(discriminator_params)
generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params)
discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params)
feat_match_loss_params = copy.deepcopy(feat_match_loss_params)
mel_loss_params = copy.deepcopy(mel_loss_params)
if model_type != "":
assert model_type in ("low", "medium", "high"), model_type
if model_type == "low":
generator_params.update(LOW_CONFIG)
elif model_type == "medium":
generator_params.update(MEDIUM_CONFIG)
elif model_type == "high":
generator_params.update(HIGH_CONFIG)
else:
raise ValueError(f"Unknown model_type: {model_type}")
# define modules
generator_class = AVAILABLE_GENERATERS[generator_type]
if generator_type == "vits_generator":