mirror of https://github.com/k2-fsa/icefall.git

commit b33d3820db (parent ae61bd4090)

Support different tts model types.

low (quality) -> runs faster; high (quality) -> runs slower.
egs/ljspeech/TTS/prepare.sh

@@ -54,7 +54,7 @@ fi
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   log "Stage 1: Prepare LJSpeech manifest"
   # We assume that you have downloaded the LJSpeech corpus
-  # to $dl_dir/LJSpeech
+  # to $dl_dir/LJSpeech-1.1
   mkdir -p data/manifests
   if [ ! -e data/manifests/.ljspeech.done ]; then
     lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
egs/ljspeech/TTS/vits/export-onnx.py

@@ -25,9 +25,8 @@ Export the model to ONNX:
   --exp-dir vits/exp \
   --tokens data/tokens.txt

-It will generate two files inside vits/exp:
+It will generate one file inside vits/exp:
   - vits-epoch-1000.onnx
-  - vits-epoch-1000.int8.onnx (quantizated model)

 See ./test_onnx.py for how to use the exported ONNX models.
 """
@@ -40,7 +39,6 @@ from typing import Dict, Tuple
 import onnx
 import torch
 import torch.nn as nn
-from onnxruntime.quantization import QuantType, quantize_dynamic
 from tokenizer import Tokenizer
 from train import get_model, get_params

@@ -75,6 +73,15 @@ def get_parser():
         help="""Path to vocabulary.""",
     )

+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="",
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )
+
     return parser


@@ -240,7 +247,7 @@ def main():
     model = OnnxModel(model=model)

     num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"generator parameters: {num_param}")
+    logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")

     suffix = f"epoch-{params.epoch}"

@@ -256,18 +263,6 @@ def main():
     )
     logging.info(f"Exported generator to {model_filename}")

-    # Generate int8 quantization models
-    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
-
-    logging.info("Generate int8 quantization models")
-
-    model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=model_filename,
-        model_output=model_filename_int8,
-        weight_type=QuantType.QUInt8,
-    )
-

 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
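Note on the removed step: export-onnx.py now emits only the float32 vits-epoch-*.onnx file. If an int8 model is still wanted, the deleted logic can be run as a standalone snippet; a minimal sketch using onnxruntime's dynamic quantization API (the file names are illustrative, taken from the docstring above):

    # Standalone re-creation of the removed int8 export step.
    from onnxruntime.quantization import QuantType, quantize_dynamic

    quantize_dynamic(
        model_input="vits/exp/vits-epoch-1000.onnx",        # float32 model from export-onnx.py
        model_output="vits/exp/vits-epoch-1000.int8.onnx",  # model with int8 weights
        weight_type=QuantType.QUInt8,
    )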
egs/ljspeech/TTS/vits/generator.py

@@ -189,7 +189,7 @@ class VITSGenerator(torch.nn.Module):
         self.upsample_factor = int(np.prod(decoder_upsample_scales))
         self.spks = None
         if spks is not None and spks > 1:
-            assert global_channels > 0
+            assert global_channels > 0, global_channels
             self.spks = spks
             self.global_emb = torch.nn.Embedding(spks, global_channels)
         self.spk_embed_dim = None
egs/ljspeech/TTS/vits/infer.py

@@ -72,6 +72,15 @@ def get_parser():
         help="""Path to vocabulary.""",
     )

+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="",
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )
+
     return parser


@@ -94,6 +103,7 @@ def infer_dataset(
       tokenizer:
         Used to convert text to phonemes.
     """
+
     # Background worker save audios to disk.
     def _save_worker(
         batch_size: int,
egs/ljspeech/TTS/vits/test_model.py (new executable file, 51 lines)

@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tokenizer import Tokenizer
+from train import get_model, get_params
+from vits import VITS
+
+
+def test_model_type(model_type):
+    tokens = "./data/tokens.txt"
+
+    params = get_params()
+
+    tokenizer = Tokenizer(tokens)
+    params.blank_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+    params.model_type = model_type
+
+    model = get_model(params)
+    generator = model.generator
+
+    num_param = sum([p.numel() for p in generator.parameters()])
+    print(
+        f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M"
+    )
+
+
+def main():
+    test_model_type("high")  # 35.63 M
+    test_model_type("low")  # 7.55 M
+    test_model_type("medium")  # 23.61 M
+    test_model_type("")  # 35.63 M
+
+
+if __name__ == "__main__":
+    main()
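The new test script exercises every size: per its inline comments, "high" (and the empty default, which keeps the original configuration) yields about 35.63 M generator parameters, "medium" about 23.61 M, and "low" about 7.55 M, so the low model is roughly a fifth the size of the high one. It has to be run from a recipe directory that already contains data/tokens.txt, since the Tokenizer is loaded from ./data/tokens.txt.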
egs/ljspeech/TTS/vits/text_encoder.py

@@ -92,9 +92,9 @@ class TextEncoder(torch.nn.Module):
             x_lengths (Tensor): Length tensor (B,).

         Returns:
-            Tensor: Encoded hidden representation (B, attention_dim, T_text).
-            Tensor: Projected mean tensor (B, attention_dim, T_text).
-            Tensor: Projected scale tensor (B, attention_dim, T_text).
+            Tensor: Encoded hidden representation (B, embed_dim, T_text).
+            Tensor: Projected mean tensor (B, embed_dim, T_text).
+            Tensor: Projected scale tensor (B, embed_dim, T_text).
             Tensor: Mask tensor for input tensor (B, 1, T_text).

         """
@@ -108,6 +108,7 @@ class TextEncoder(torch.nn.Module):

         # encoder assume the channel last (B, T_text, embed_dim)
         x = self.encoder(x, key_padding_mask=pad_mask)
+        # Note: attention_dim == embed_dim

         # convert the channel first (B, embed_dim, T_text)
         x = x.transpose(1, 2)
egs/ljspeech/TTS/vits/train.py

@@ -153,6 +153,15 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="",
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )
+
     return parser


@@ -189,15 +198,6 @@ def get_params() -> AttributeDict:

        - feature_dim: The model input dim. It has to match the one used
                       in computing features.
-
-        - subsampling_factor: The subsampling factor for the model.
-
-        - encoder_dim: Hidden dim for multi-head attention model.
-
-        - num_decoder_layers: Number of decoder layer of transformer decoder.
-
-        - warm_step: The warmup period that dictates the decay of the
-              scale on "simple" (un-pruned) loss.
    """
    params = AttributeDict(
        {
@@ -278,6 +278,7 @@ def get_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         feature_dim=params.feature_dim,
         sampling_rate=params.sampling_rate,
+        model_type=params.model_type,
         mel_loss_params=mel_loss_params,
         lambda_adv=params.lambda_adv,
         lambda_mel=params.lambda_mel,
@@ -363,7 +364,7 @@ def train_one_epoch(
     model.train()
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device

-    # used to summary the stats over iterations in one epoch
+    # used to track the stats over iterations in one epoch
     tot_loss = MetricsTracker()

     saved_bad_model = False
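The same --model-type flag is added to export-onnx.py, infer.py, and train.py above. A minimal sketch (not part of the commit) of the path the flag takes: argparse parses it, the recipe scripts copy the parsed arguments onto params, and get_model forwards params.model_type to the VITS constructor, which validates non-empty values:

    # Hedged illustration of the flag's flow; get_model and AttributeDict are
    # icefall helpers and are not reproduced here.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-type", type=str, default="")
    args = parser.parse_args(["--model-type", "low"])

    # VITS.__init__ enforces this whenever the value is non-empty:
    if args.model_type != "":
        assert args.model_type in ("low", "medium", "high"), args.model_type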
egs/ljspeech/TTS/vits/vits.py

@@ -5,6 +5,7 @@

 """VITS module for GAN-TTS task."""

+import copy
 from typing import Any, Dict, Optional, Tuple

 import torch
@@ -38,6 +39,36 @@ AVAILABLE_DISCRIMINATORS = {
     "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator,  # NOQA
 }

+LOW_CONFIG = {
+    "hidden_channels": 96,
+    "decoder_upsample_scales": (8, 8, 4),
+    "decoder_channels": 256,
+    "decoder_upsample_kernel_sizes": (16, 16, 8),
+    "decoder_resblock_kernel_sizes": (3, 5, 7),
+    "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
+    "text_encoder_cnn_module_kernel": 3,
+}
+
+MEDIUM_CONFIG = {
+    "hidden_channels": 192,
+    "decoder_upsample_scales": (8, 8, 4),
+    "decoder_channels": 256,
+    "decoder_upsample_kernel_sizes": (16, 16, 8),
+    "decoder_resblock_kernel_sizes": (3, 5, 7),
+    "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
+    "text_encoder_cnn_module_kernel": 3,
+}
+
+HIGH_CONFIG = {
+    "hidden_channels": 192,
+    "decoder_upsample_scales": (8, 8, 2, 2),
+    "decoder_channels": 512,
+    "decoder_upsample_kernel_sizes": (16, 16, 4, 4),
+    "decoder_resblock_kernel_sizes": (3, 7, 11),
+    "decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+    "text_encoder_cnn_module_kernel": 3,
+}
+
 
 class VITS(nn.Module):
     """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
@@ -49,6 +80,7 @@ class VITS(nn.Module):
         feature_dim: int = 513,
         sampling_rate: int = 22050,
         generator_type: str = "vits_generator",
+        model_type: str = "",
         generator_params: Dict[str, Any] = {
             "hidden_channels": 192,
             "spks": None,
@@ -155,12 +187,13 @@ class VITS(nn.Module):
         """Initialize VITS module.

         Args:
-            idim (int): Input vocabrary size.
+            idim (int): Input vocabulary size.
             odim (int): Acoustic feature dimension. The actual output channels will
                 be 1 since VITS is the end-to-end text-to-wave model but for the
                 compatibility odim is used to indicate the acoustic feature dimension.
             sampling_rate (int): Sampling rate, not used for the training but it will
                 be referred in saving waveform during the inference.
+            model_type (str): If not empty, must be one of: low, medium, high
             generator_type (str): Generator type.
             generator_params (Dict[str, Any]): Parameter dict for generator.
             discriminator_type (str): Discriminator type.
@@ -181,6 +214,24 @@ class VITS(nn.Module):
         """
         super().__init__()

+        generator_params = copy.deepcopy(generator_params)
+        discriminator_params = copy.deepcopy(discriminator_params)
+        generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params)
+        discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params)
+        feat_match_loss_params = copy.deepcopy(feat_match_loss_params)
+        mel_loss_params = copy.deepcopy(mel_loss_params)
+
+        if model_type != "":
+            assert model_type in ("low", "medium", "high"), model_type
+            if model_type == "low":
+                generator_params.update(LOW_CONFIG)
+            elif model_type == "medium":
+                generator_params.update(MEDIUM_CONFIG)
+            elif model_type == "high":
+                generator_params.update(HIGH_CONFIG)
+            else:
+                raise ValueError(f"Unknown model_type: ${model_type}")
+
         # define modules
         generator_class = AVAILABLE_GENERATERS[generator_type]
         if generator_type == "vits_generator":
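Two details in the vits.py changes are easy to miss. First, generator_params is a mutable default argument shared by every call, so it is deep-copied before dict.update merges the chosen size preset into it; without the copy, constructing one "low" model would leak the smaller settings into all later instances. Second, every preset keeps the product of decoder_upsample_scales at 256, so the generator's upsample_factor (computed as int(np.prod(decoder_upsample_scales)) in generator.py above), and hence the number of output samples per frame, is identical across sizes. A small self-contained sketch of both points:

    import copy

    import numpy as np

    # All three presets share the same total upsample factor.
    for name, scales in [("low/medium", (8, 8, 4)), ("high", (8, 8, 2, 2))]:
        print(name, int(np.prod(scales)))  # prints 256 for both

    # Why the deepcopy matters: updating the shared default dict in place
    # would change it for every later construction.
    defaults = {"hidden_channels": 192, "decoder_channels": 512}
    params = copy.deepcopy(defaults)
    params.update({"hidden_channels": 96, "decoder_channels": 256})  # LOW_CONFIG subset
    assert defaults["hidden_channels"] == 192  # defaults remain untouched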