Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)
Support different TTS model types: low quality -> runs faster; high quality -> runs slower.
Parent: ae61bd4090
Commit: b33d3820db
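The commit threads a new --model-type flag (low / medium / high, empty string for the previous defaults) from train.py, infer.py, and export-onnx.py down into VITS.__init__, which overlays a size-specific generator config. A minimal sketch of the resulting API, mirroring the new test_model.py shown further below (all names come from this commit's diffs):

from tokenizer import Tokenizer
from train import get_model, get_params

params = get_params()
tokenizer = Tokenizer("./data/tokens.txt")
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_type = "low"  # "" keeps the old default configuration

model = get_model(params)  # generator hyper-parameters are overridden inside VITS.__init__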
egs/ljspeech/TTS/prepare.sh

@@ -54,7 +54,7 @@ fi
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "Stage 1: Prepare LJSpeech manifest"
   # We assume that you have downloaded the LJSpeech corpus
-  # to $dl_dir/LJSpeech
+  # to $dl_dir/LJSpeech-1.1
   mkdir -p data/manifests
   if [ ! -e data/manifests/.ljspeech.done ]; then
     lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
egs/ljspeech/TTS/vits/export-onnx.py

@@ -25,9 +25,8 @@ Export the model to ONNX:
   --exp-dir vits/exp \
   --tokens data/tokens.txt

-It will generate two files inside vits/exp:
+It will generate one file inside vits/exp:
 - vits-epoch-1000.onnx
-- vits-epoch-1000.int8.onnx (quantizated model)

 See ./test_onnx.py for how to use the exported ONNX models.
 """
@@ -40,7 +39,6 @@ from typing import Dict, Tuple
 import onnx
 import torch
 import torch.nn as nn
-from onnxruntime.quantization import QuantType, quantize_dynamic
 from tokenizer import Tokenizer
 from train import get_model, get_params
@@ -75,6 +73,15 @@ def get_parser():
         help="""Path to vocabulary.""",
     )

+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="",
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )
+
     return parser
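The same --model-type argument is added verbatim to export-onnx.py (above), infer.py, and train.py (below). A stand-alone replica for quick experimentation (plain argparse, nothing icefall-specific); note that value validation happens later, via the assert in VITS.__init__, not here:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model-type",
    type=str,
    default="",
    help="If not empty, valid values are: low, medium, high.",
)

args = parser.parse_args(["--model-type", "medium"])
print(args.model_type)  # "medium" -- argparse maps --model-type to args.model_type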
@@ -240,7 +247,7 @@ def main():
     model = OnnxModel(model=model)

     num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"generator parameters: {num_param}")
+    logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")

     suffix = f"epoch-{params.epoch}"
@@ -256,18 +263,6 @@
     )
     logging.info(f"Exported generator to {model_filename}")

-    # Generate int8 quantization models
-    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
-
-    logging.info("Generate int8 quantization models")
-
-    model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=model_filename,
-        model_output=model_filename_int8,
-        weight_type=QuantType.QUInt8,
-    )
-

 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
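With this hunk the automatic int8 export is gone, matching the docstring change above ("one file"). If a quantized copy is still wanted, the removed logic can be run as a post-processing step; the calls below are taken directly from the deleted lines, with the old vits-epoch-*.onnx file names filled in as an example:

from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="vits/exp/vits-epoch-1000.onnx",
    model_output="vits/exp/vits-epoch-1000.int8.onnx",
    weight_type=QuantType.QUInt8,
)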
egs/ljspeech/TTS/vits/generator.py

@@ -189,7 +189,7 @@ class VITSGenerator(torch.nn.Module):
         self.upsample_factor = int(np.prod(decoder_upsample_scales))
         self.spks = None
         if spks is not None and spks > 1:
-            assert global_channels > 0
+            assert global_channels > 0, global_channels
             self.spks = spks
             self.global_emb = torch.nn.Embedding(spks, global_channels)
         self.spk_embed_dim = None
egs/ljspeech/TTS/vits/infer.py

@@ -72,6 +72,15 @@ def get_parser():
         help="""Path to vocabulary.""",
     )

+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="",
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )
+
     return parser
@@ -94,6 +103,7 @@ def infer_dataset(
       tokenizer:
         Used to convert text to phonemes.
     """

+    # Background worker save audios to disk.
     def _save_worker(
         batch_size: int,
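The diff only adds the comment and shows the head of _save_worker; its body is not part of this commit. For orientation, a generic version of the pattern the comment describes, i.e. a background thread draining a queue of synthesized audio to disk (queue layout and the soundfile writer are assumptions, not icefall's actual code):

import queue
import threading

import soundfile as sf  # assumed writer; any audio I/O library works

audio_queue: queue.Queue = queue.Queue()

def save_worker_sketch():
    while True:
        item = audio_queue.get()
        if item is None:  # sentinel: generation finished
            break
        filename, audio, sampling_rate = item
        sf.write(filename, audio, sampling_rate)
        audio_queue.task_done()

threading.Thread(target=save_worker_sketch, daemon=True).start()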
egs/ljspeech/TTS/vits/test_model.py (new executable file, 51 lines)
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tokenizer import Tokenizer
+from train import get_model, get_params
+from vits import VITS
+
+
+def test_model_type(model_type):
+    tokens = "./data/tokens.txt"
+
+    params = get_params()
+
+    tokenizer = Tokenizer(tokens)
+    params.blank_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+    params.model_type = model_type
+
+    model = get_model(params)
+    generator = model.generator
+
+    num_param = sum([p.numel() for p in generator.parameters()])
+    print(
+        f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M"
+    )
+
+
+def main():
+    test_model_type("high")  # 35.63 M
+    test_model_type("low")  # 7.55 M
+    test_model_type("medium")  # 23.61 M
+    test_model_type("")  # 35.63 M
+
+
+if __name__ == "__main__":
+    main()
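Per the inline comments, the presets land at 7.55 M (low), 23.61 M (medium), and 35.63 M (high) generator parameters, and an empty --model-type reproduces the 35.63 M figure: the pre-commit defaults match the high preset's size, so existing recipes and checkpoints are unaffected by this change.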
egs/ljspeech/TTS/vits/text_encoder.py

@@ -92,9 +92,9 @@ class TextEncoder(torch.nn.Module):
             x_lengths (Tensor): Length tensor (B,).

         Returns:
-            Tensor: Encoded hidden representation (B, attention_dim, T_text).
-            Tensor: Projected mean tensor (B, attention_dim, T_text).
-            Tensor: Projected scale tensor (B, attention_dim, T_text).
+            Tensor: Encoded hidden representation (B, embed_dim, T_text).
+            Tensor: Projected mean tensor (B, embed_dim, T_text).
+            Tensor: Projected scale tensor (B, embed_dim, T_text).
             Tensor: Mask tensor for input tensor (B, 1, T_text).

         """
@@ -108,6 +108,7 @@ class TextEncoder(torch.nn.Module):

         # encoder assume the channel last (B, T_text, embed_dim)
         x = self.encoder(x, key_padding_mask=pad_mask)
+        # Note: attention_dim == embed_dim

         # convert the channel first (B, embed_dim, T_text)
         x = x.transpose(1, 2)
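The new comment pins down that attention_dim equals embed_dim, which is exactly why the docstring hunk above swaps one name for the other. The channel-last/channel-first convention in this method, illustrated standalone:

import torch

B, T_text, embed_dim = 2, 7, 192  # arbitrary demo sizes
x = torch.randn(B, T_text, embed_dim)  # channel-last: what self.encoder consumes
x = x.transpose(1, 2)                  # channel-first: (B, embed_dim, T_text)
print(x.shape)  # torch.Size([2, 192, 7])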
egs/ljspeech/TTS/vits/train.py

@@ -153,6 +153,15 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="",
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )
+
     return parser
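A stricter variant (not what this commit does) would let argparse itself reject invalid values up front via choices, instead of deferring to the assert in VITS.__init__:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model-type",
    type=str,
    default="",
    choices=["", "low", "medium", "high"],
    help="Controls the model size; low -> runs faster.",
)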
@@ -189,15 +198,6 @@ def get_params() -> AttributeDict:

-        - feature_dim: The model input dim. It has to match the one used
-          in computing features.
-
-        - subsampling_factor: The subsampling factor for the model.
-
-        - encoder_dim: Hidden dim for multi-head attention model.
-
-        - num_decoder_layers: Number of decoder layer of transformer decoder.
-
         - warm_step: The warmup period that dictates the decay of the
           scale on "simple" (un-pruned) loss.
     """
     params = AttributeDict(
         {
@@ -278,6 +278,7 @@ def get_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         feature_dim=params.feature_dim,
         sampling_rate=params.sampling_rate,
+        model_type=params.model_type,
         mel_loss_params=mel_loss_params,
         lambda_adv=params.lambda_adv,
         lambda_mel=params.lambda_mel,
@@ -363,7 +364,7 @@ def train_one_epoch(
     model.train()
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device

-    # used to summary the stats over iterations in one epoch
+    # used to track the stats over iterations in one epoch
     tot_loss = MetricsTracker()

     saved_bad_model = False
egs/ljspeech/TTS/vits/vits.py

@@ -5,6 +5,7 @@

 """VITS module for GAN-TTS task."""

+import copy
 from typing import Any, Dict, Optional, Tuple

 import torch
@@ -38,6 +39,36 @@ AVAILABLE_DISCRIMINATORS = {
     "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator,  # NOQA
 }

+LOW_CONFIG = {
+    "hidden_channels": 96,
+    "decoder_upsample_scales": (8, 8, 4),
+    "decoder_channels": 256,
+    "decoder_upsample_kernel_sizes": (16, 16, 8),
+    "decoder_resblock_kernel_sizes": (3, 5, 7),
+    "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
+    "text_encoder_cnn_module_kernel": 3,
+}
+
+MEDIUM_CONFIG = {
+    "hidden_channels": 192,
+    "decoder_upsample_scales": (8, 8, 4),
+    "decoder_channels": 256,
+    "decoder_upsample_kernel_sizes": (16, 16, 8),
+    "decoder_resblock_kernel_sizes": (3, 5, 7),
+    "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
+    "text_encoder_cnn_module_kernel": 3,
+}
+
+HIGH_CONFIG = {
+    "hidden_channels": 192,
+    "decoder_upsample_scales": (8, 8, 2, 2),
+    "decoder_channels": 512,
+    "decoder_upsample_kernel_sizes": (16, 16, 4, 4),
+    "decoder_resblock_kernel_sizes": (3, 7, 11),
+    "decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+    "text_encoder_cnn_module_kernel": 3,
+}
+
+
 class VITS(nn.Module):
     """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
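A property worth noticing in the three presets: the product of decoder_upsample_scales is the generator's waveform upsampling factor (see self.upsample_factor = int(np.prod(decoder_upsample_scales)) in the generator.py hunk above), and all three presets keep it at 256 samples per frame, so they trade width (channels, kernel sizes, dilations) without changing the acoustic frame rate:

import numpy as np

for name, scales in {
    "low": (8, 8, 4),
    "medium": (8, 8, 4),
    "high": (8, 8, 2, 2),
}.items():
    print(name, int(np.prod(scales)))  # all three print 256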
@@ -49,6 +80,7 @@ class VITS(nn.Module):
         feature_dim: int = 513,
         sampling_rate: int = 22050,
         generator_type: str = "vits_generator",
+        model_type: str = "",
         generator_params: Dict[str, Any] = {
             "hidden_channels": 192,
             "spks": None,
@@ -155,12 +187,13 @@ class VITS(nn.Module):
         """Initialize VITS module.

         Args:
-            idim (int): Input vocabrary size.
+            idim (int): Input vocabulary size.
             odim (int): Acoustic feature dimension. The actual output channels will
                 be 1 since VITS is the end-to-end text-to-wave model but for the
                 compatibility odim is used to indicate the acoustic feature dimension.
             sampling_rate (int): Sampling rate, not used for the training but it will
                 be referred in saving waveform during the inference.
+            model_type (str): If not empty, must be one of: low, medium, high
             generator_type (str): Generator type.
             generator_params (Dict[str, Any]): Parameter dict for generator.
             discriminator_type (str): Discriminator type.
@@ -181,6 +214,24 @@ class VITS(nn.Module):
         """
         super().__init__()

+        generator_params = copy.deepcopy(generator_params)
+        discriminator_params = copy.deepcopy(discriminator_params)
+        generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params)
+        discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params)
+        feat_match_loss_params = copy.deepcopy(feat_match_loss_params)
+        mel_loss_params = copy.deepcopy(mel_loss_params)
+
+        if model_type != "":
+            assert model_type in ("low", "medium", "high"), model_type
+            if model_type == "low":
+                generator_params.update(LOW_CONFIG)
+            elif model_type == "medium":
+                generator_params.update(MEDIUM_CONFIG)
+            elif model_type == "high":
+                generator_params.update(HIGH_CONFIG)
+            else:
+                raise ValueError(f"Unknown model_type: ${model_type}")
+
         # define modules
         generator_class = AVAILABLE_GENERATERS[generator_type]
         if generator_type == "vits_generator":
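The new deepcopy calls matter because generator_params and the other param dicts are mutable default arguments (see the dict literal in the __init__ signature above): without a copy, generator_params.update(LOW_CONFIG) would rewrite the shared default dict for every later VITS instance. A minimal demonstration of that hazard, with toy names rather than icefall's code:

import copy

DEFAULT = {"hidden_channels": 192}

def make_broken(cfg=DEFAULT):
    cfg.update({"hidden_channels": 96})  # mutates the shared default!
    return cfg

def make_fixed(cfg=DEFAULT):
    cfg = copy.deepcopy(cfg)             # this commit's approach
    cfg.update({"hidden_channels": 96})
    return cfg

make_fixed()
print(DEFAULT)  # {'hidden_channels': 192} -- untouched
make_broken()
print(DEFAULT)  # {'hidden_channels': 96}  -- the default itself was rewritten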