Support different TTS model types. (#1541)
Parent: 959906e9dc
Commit: 81f518ea7c
@@ -56,7 +56,8 @@ Training
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir vits/exp \
-  --tokens data/tokens.txt
+  --tokens data/tokens.txt \
+  --model-type high \
  --max-duration 500

.. note::
@@ -64,6 +65,11 @@ Training
You can adjust the hyper-parameters to control the size of the VITS model and
the training configurations. For more details, please run ``./vits/train.py --help``.

+.. warning::
+
+  If you want a model that runs faster on CPU, please use ``--model-type low``
+  or ``--model-type medium``.
+
.. note::

  The training can take a long time (usually a couple of days).
@@ -95,8 +101,8 @@ training part first. It will save the ground-truth and generated wavs to the dir
Export models
-------------

-Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
-``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
+Currently we only support ONNX model exporting. It will generate one file in the given ``exp-dir``:
+``vits-epoch-*.onnx``.

.. code-block:: bash

@@ -120,4 +126,7 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following link:

-- `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
+- ``--model-type=high``: `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_
+- ``--model-type=medium``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>`_
+- ``--model-type=low``: `<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>`_
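
One way (a sketch, not part of the recipe) to fetch one of the above checkpoints from
Python is to use the ``huggingface_hub`` package; pick the repo id that matches your
``--model-type`` (cloning the repo with ``git lfs`` works equally well):

.. code-block:: python

  from huggingface_hub import snapshot_download

  # Example repo id; replace it with the one matching your --model-type.
  local_dir = snapshot_download(
      repo_id="csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12"
  )
  print("checkpoint downloaded to", local_dir)
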
@@ -1,10 +1,10 @@
# Introduction

This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books.
A transcription is provided for each clip.
Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours.

The texts were published between 1884 and 1964, and are in the public domain.
The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain.

The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/).
@@ -35,4 +35,69 @@ To inference, use:
  --exp-dir vits/exp \
  --epoch 1000 \
  --tokens data/tokens.txt
```

## Quality vs speed

If you feel that the trained model is slow at runtime, you can specify the
argument `--model-type` during training. Possible values are:

- `low`, means **low** quality. The resulting model is very small in file size
  and runs very fast. The following is a wave file generated by a `low` quality model:

  https://github.com/k2-fsa/icefall/assets/5284924/d5758c24-470d-40ee-b089-e57fcba81633

  The text is `Ask not what your country can do for you; ask what you can do for your country.`

  The exported onnx model has a file size of ``26.8 MB`` (float32).

- `medium`, means **medium** quality.
  The following is a wave file generated by a `medium` quality model:

  https://github.com/k2-fsa/icefall/assets/5284924/b199d960-3665-4d0d-9ae9-a1bb69cbc8ac

  The text is `Ask not what your country can do for you; ask what you can do for your country.`

  The exported onnx model has a file size of ``70.9 MB`` (float32).

- `high`, means **high** quality. This is the default value.

  The following is a wave file generated by a `high` quality model:

  https://github.com/k2-fsa/icefall/assets/5284924/b39f3048-73a6-4267-bf95-df5abfdb28fc

  The text is `Ask not what your country can do for you; ask what you can do for your country.`

  The exported onnx model has a file size of ``113 MB`` (float32).

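As a quick sanity check (illustrative only, not part of the recipe), you can compare
the size of your own exported model against the figures quoted above; the path below
is an example and should be adjusted to your `exp-dir` and epoch:

```python
import os

path = "vits/exp/vits-epoch-1000.onnx"  # adjust to your exp-dir / epoch
print(f"{path}: {os.path.getsize(path) / 1e6:.1f} MB")
```
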
A pre-trained `low` model trained using 4 x V100 32GB GPUs with the following command can be found at
<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-low-2024-03-12>

```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
./vits/train.py \
  --world-size 4 \
  --num-epochs 1601 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir vits/exp \
  --model-type low \
  --max-duration 800
```

A pre-trained `medium` model trained using 4 x V100 32GB GPUs with the following command can be found at
<https://huggingface.co/csukuangfj/icefall-tts-ljspeech-vits-medium-2024-03-12>

```bash
export CUDA_VISIBLE_DEVICES=4,5,6,7
./vits/train.py \
  --world-size 4 \
  --num-epochs 1000 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir vits/exp-medium \
  --model-type medium \
  --max-duration 500

# (Note: training was stopped after `epoch-820.pt`.)
```

@@ -54,7 +54,7 @@ fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare LJSpeech manifest"
  # We assume that you have downloaded the LJSpeech corpus
-  # to $dl_dir/LJSpeech
+  # to $dl_dir/LJSpeech-1.1
  mkdir -p data/manifests
  if [ ! -e data/manifests/.ljspeech.done ]; then
    lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
@@ -25,9 +25,8 @@ Export the model to ONNX:
  --exp-dir vits/exp \
  --tokens data/tokens.txt

-It will generate two files inside vits/exp:
+It will generate one file inside vits/exp:
  - vits-epoch-1000.onnx
-  - vits-epoch-1000.int8.onnx (quantized model)

See ./test_onnx.py for how to use the exported ONNX models.
"""
@@ -40,7 +39,6 @@ from typing import Dict, Tuple
import onnx
import torch
import torch.nn as nn
-from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params

@@ -75,6 +73,16 @@ def get_parser():
        help="""Path to vocabulary.""",
    )

+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="high",
+        choices=["low", "medium", "high"],
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )

    return parser

@@ -136,7 +144,7 @@ class OnnxModel(nn.Module):
        Return a tuple containing:
          - audio, generated waveform tensor, (B, T_wav)
        """
-        audio, _, _ = self.model.inference(
+        audio, _, _ = self.model.generator.inference(
            text=tokens,
            text_lengths=tokens_lens,
            noise_scale=noise_scale,
@@ -198,6 +206,11 @@ def export_model_onnx(
        },
    )

+    if model.model.spks is None:
+        num_speakers = 1
+    else:
+        num_speakers = model.model.spks
+
    meta_data = {
        "model_type": "vits",
        "version": "1",
@@ -206,8 +219,8 @@ def export_model_onnx(
        "language": "English",
        "voice": "en-us",  # Choose your language appropriately
        "has_espeak": 1,
-        "n_speakers": 1,
-        "sample_rate": 22050,  # Must match the real sample rate
+        "n_speakers": num_speakers,
+        "sample_rate": model.model.sampling_rate,  # Must match the real sample rate
    }
    logging.info(f"meta_data: {meta_data}")

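The metadata written above can be read back from the exported file with `onnxruntime`;
this is the same mechanism `test_onnx.py` uses below to pick up the sample rate.
A minimal sketch (the model path is an example):

```python
import onnxruntime as ort

session = ort.InferenceSession(
    "vits/exp/vits-epoch-1000.onnx",  # adjust to your exported model
    providers=["CPUExecutionProvider"],
)
meta = session.get_modelmeta().custom_metadata_map
print("n_speakers  =", meta["n_speakers"])
print("sample_rate =", meta["sample_rate"])
```
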
@@ -233,14 +246,13 @@ def main():

    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

-    model = model.generator
    model.to("cpu")
    model.eval()

    model = OnnxModel(model=model)

    num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"generator parameters: {num_param}")
+    logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")

    suffix = f"epoch-{params.epoch}"

@@ -256,18 +268,6 @@ def main():
    )
    logging.info(f"Exported generator to {model_filename}")

-    # Generate int8 quantization models
-    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
-
-    logging.info("Generate int8 quantization models")
-
-    model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=model_filename,
-        model_output=model_filename_int8,
-        weight_type=QuantType.QUInt8,
-    )
-

if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
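
With this change the exporter no longer writes an int8 model. If you still want one,
the removed step can be run manually after export; a sketch using the same
`quantize_dynamic` call (file names are examples):

```python
from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="vits/exp/vits-epoch-1000.onnx",        # the exported float32 model
    model_output="vits/exp/vits-epoch-1000.int8.onnx",  # where to write the int8 model
    weight_type=QuantType.QUInt8,
)
```
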
@@ -189,7 +189,7 @@ class VITSGenerator(torch.nn.Module):
        self.upsample_factor = int(np.prod(decoder_upsample_scales))
        self.spks = None
        if spks is not None and spks > 1:
-            assert global_channels > 0
+            assert global_channels > 0, global_channels
            self.spks = spks
            self.global_emb = torch.nn.Embedding(spks, global_channels)
        self.spk_embed_dim = None
@@ -72,6 +72,16 @@ def get_parser():
        help="""Path to vocabulary.""",
    )

+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="high",
+        choices=["low", "medium", "high"],
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )

    return parser

@@ -94,6 +104,7 @@ def infer_dataset(
      tokenizer:
        Used to convert text to phonemes.
    """
+
    # Background worker save audios to disk.
    def _save_worker(
        batch_size: int,
egs/ljspeech/TTS/vits/test_model.py (new executable file, 50 lines)
@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from tokenizer import Tokenizer
from train import get_model, get_params
from vits import VITS


def test_model_type(model_type):
    tokens = "./data/tokens.txt"

    params = get_params()

    tokenizer = Tokenizer(tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_type = model_type

    model = get_model(params)
    generator = model.generator

    num_param = sum([p.numel() for p in generator.parameters()])
    print(
        f"{model_type}: generator parameters: {num_param}, or {num_param/1000/1000} M"
    )


def main():
    test_model_type("high")  # 35.63 M
    test_model_type("low")  # 7.55 M
    test_model_type("medium")  # 23.61 M


if __name__ == "__main__":
    main()
@@ -54,6 +54,20 @@ def get_parser():
        help="""Path to vocabulary.""",
    )

+    parser.add_argument(
+        "--text",
+        type=str,
+        default="Ask not what your country can do for you; ask what you can do for your country.",
+        help="Text to generate speech for",
+    )
+
+    parser.add_argument(
+        "--output-filename",
+        type=str,
+        default="test_onnx.wav",
+        help="Filename to save the generated wave file.",
+    )

    return parser

@@ -61,7 +75,7 @@ class OnnxModel:
    def __init__(self, model_filename: str):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
-        session_opts.intra_op_num_threads = 4
+        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

@@ -72,6 +86,9 @@ class OnnxModel:
        )
        logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")

+        metadata = self.model.get_modelmeta().custom_metadata_map
+        self.sample_rate = int(metadata["sample_rate"])
+
    def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
@@ -101,13 +118,14 @@

def main():
    args = get_parser().parse_args()
    logging.info(vars(args))

    tokenizer = Tokenizer(args.tokens)

    logging.info("About to create onnx model")
    model = OnnxModel(args.model_filename)

-    text = "I went there to see the land, the people and how their system works, end quote."
+    text = args.text
    tokens = tokenizer.texts_to_token_ids(
        [text], intersperse_blank=True, add_sos=True, add_eos=True
    )
@@ -115,8 +133,9 @@ def main():
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1, T)
    audio = model(tokens, tokens_lens)  # (1, T')

-    torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
-    logging.info("Saved to test_onnx.wav")
+    output_filename = args.output_filename
+    torchaudio.save(output_filename, audio, sample_rate=model.sample_rate)
+    logging.info(f"Saved to {output_filename}")


if __name__ == "__main__":
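
Building on the new `--text` and `--output-filename` flags, here is a sketch of batch
synthesis that reuses the `OnnxModel` and `Tokenizer` classes from `test_onnx.py`. It
assumes the snippet sits next to `test_onnx.py` so the imports resolve, and the paths
are placeholders to adjust to your layout:

```python
import torch
import torchaudio

from test_onnx import OnnxModel
from tokenizer import Tokenizer

tokenizer = Tokenizer("data/tokens.txt")            # adjust path
model = OnnxModel("vits/exp/vits-epoch-1000.onnx")  # adjust path

sentences = [
    "Ask not what your country can do for you; ask what you can do for your country.",
    "The quick brown fox jumps over the lazy dog.",
]
for i, text in enumerate(sentences):
    token_ids = tokenizer.texts_to_token_ids(
        [text], intersperse_blank=True, add_sos=True, add_eos=True
    )
    tokens = torch.tensor(token_ids, dtype=torch.int64)  # (1, T)
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
    audio = model(tokens, tokens_lens)  # (1, T')
    torchaudio.save(f"sentence-{i}.wav", audio, sample_rate=model.sample_rate)
```
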
@@ -92,9 +92,9 @@ class TextEncoder(torch.nn.Module):
            x_lengths (Tensor): Length tensor (B,).

        Returns:
-            Tensor: Encoded hidden representation (B, attention_dim, T_text).
-            Tensor: Projected mean tensor (B, attention_dim, T_text).
-            Tensor: Projected scale tensor (B, attention_dim, T_text).
+            Tensor: Encoded hidden representation (B, embed_dim, T_text).
+            Tensor: Projected mean tensor (B, embed_dim, T_text).
+            Tensor: Projected scale tensor (B, embed_dim, T_text).
            Tensor: Mask tensor for input tensor (B, 1, T_text).

        """
@@ -108,6 +108,7 @@ class TextEncoder(torch.nn.Module):

        # encoder assume the channel last (B, T_text, embed_dim)
        x = self.encoder(x, key_padding_mask=pad_mask)
+        # Note: attention_dim == embed_dim

        # convert the channel first (B, embed_dim, T_text)
        x = x.transpose(1, 2)
@@ -18,7 +18,15 @@ import logging
from typing import Dict, List

import tacotron_cleaner.cleaners
-from piper_phonemize import phonemize_espeak
+
+try:
+    from piper_phonemize import phonemize_espeak
+except Exception as ex:
+    raise RuntimeError(
+        f"{ex}\nPlease follow instructions in "
+        "../prepare.sh to install piper-phonemize"
+    )
+
from utils import intersperse

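The guarded import above is easy to smoke-test on its own; a minimal sketch
(assuming piper-phonemize has been installed as described in `prepare.sh`):

```python
from piper_phonemize import phonemize_espeak

# One list of phoneme symbols is returned per input text.
phonemes = phonemize_espeak("Ask not what your country can do for you.", "en-us")
print(phonemes)
```
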
@@ -153,6 +153,16 @@ def get_parser():
        help="Whether to use half precision training.",
    )

+    parser.add_argument(
+        "--model-type",
+        type=str,
+        default="high",
+        choices=["low", "medium", "high"],
+        help="""If not empty, valid values are: low, medium, high.
+        It controls the model size. low -> runs faster.
+        """,
+    )

    return parser

@@ -189,15 +199,6 @@ def get_params() -> AttributeDict:

    - feature_dim: The model input dim. It has to match the one used
      in computing features.
-
-    - subsampling_factor: The subsampling factor for the model.
-
-    - encoder_dim: Hidden dim for multi-head attention model.
-
-    - num_decoder_layers: Number of decoder layer of transformer decoder.
-
-    - warm_step: The warmup period that dictates the decay of the
-      scale on "simple" (un-pruned) loss.
    """
    params = AttributeDict(
        {
@@ -278,6 +279,7 @@ def get_model(params: AttributeDict) -> nn.Module:
        vocab_size=params.vocab_size,
        feature_dim=params.feature_dim,
        sampling_rate=params.sampling_rate,
+        model_type=params.model_type,
        mel_loss_params=mel_loss_params,
        lambda_adv=params.lambda_adv,
        lambda_mel=params.lambda_mel,
@@ -363,7 +365,7 @@ def train_one_epoch(
    model.train()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

-    # used to summary the stats over iterations in one epoch
+    # used to track the stats over iterations in one epoch
    tot_loss = MetricsTracker()

    saved_bad_model = False
@@ -5,6 +5,7 @@

"""VITS module for GAN-TTS task."""

+import copy
from typing import Any, Dict, Optional, Tuple

import torch
@@ -38,6 +39,36 @@ AVAILABLE_DISCRIMINATORS = {
    "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator,  # NOQA
}

LOW_CONFIG = {
    "hidden_channels": 96,
    "decoder_upsample_scales": (8, 8, 4),
    "decoder_channels": 256,
    "decoder_upsample_kernel_sizes": (16, 16, 8),
    "decoder_resblock_kernel_sizes": (3, 5, 7),
    "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
    "text_encoder_cnn_module_kernel": 3,
}

MEDIUM_CONFIG = {
    "hidden_channels": 192,
    "decoder_upsample_scales": (8, 8, 4),
    "decoder_channels": 256,
    "decoder_upsample_kernel_sizes": (16, 16, 8),
    "decoder_resblock_kernel_sizes": (3, 5, 7),
    "decoder_resblock_dilations": ((1, 2), (2, 6), (3, 12)),
    "text_encoder_cnn_module_kernel": 3,
}

HIGH_CONFIG = {
    "hidden_channels": 192,
    "decoder_upsample_scales": (8, 8, 2, 2),
    "decoder_channels": 512,
    "decoder_upsample_kernel_sizes": (16, 16, 4, 4),
    "decoder_resblock_kernel_sizes": (3, 7, 11),
    "decoder_resblock_dilations": ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
    "text_encoder_cnn_module_kernel": 5,
}


class VITS(nn.Module):
    """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
@@ -49,6 +80,7 @@ class VITS(nn.Module):
        feature_dim: int = 513,
        sampling_rate: int = 22050,
        generator_type: str = "vits_generator",
+        model_type: str = "",
        generator_params: Dict[str, Any] = {
            "hidden_channels": 192,
            "spks": None,
@@ -155,12 +187,13 @@ class VITS(nn.Module):
        """Initialize VITS module.

        Args:
-            idim (int): Input vocabrary size.
+            idim (int): Input vocabulary size.
            odim (int): Acoustic feature dimension. The actual output channels will
                be 1 since VITS is the end-to-end text-to-wave model but for the
                compatibility odim is used to indicate the acoustic feature dimension.
            sampling_rate (int): Sampling rate, not used for the training but it will
                be referred in saving waveform during the inference.
+            model_type (str): If not empty, must be one of: low, medium, high
            generator_type (str): Generator type.
            generator_params (Dict[str, Any]): Parameter dict for generator.
            discriminator_type (str): Discriminator type.
@@ -181,6 +214,24 @@ class VITS(nn.Module):
        """
        super().__init__()

+        generator_params = copy.deepcopy(generator_params)
+        discriminator_params = copy.deepcopy(discriminator_params)
+        generator_adv_loss_params = copy.deepcopy(generator_adv_loss_params)
+        discriminator_adv_loss_params = copy.deepcopy(discriminator_adv_loss_params)
+        feat_match_loss_params = copy.deepcopy(feat_match_loss_params)
+        mel_loss_params = copy.deepcopy(mel_loss_params)
+
+        if model_type != "":
+            assert model_type in ("low", "medium", "high"), model_type
+            if model_type == "low":
+                generator_params.update(LOW_CONFIG)
+            elif model_type == "medium":
+                generator_params.update(MEDIUM_CONFIG)
+            elif model_type == "high":
+                generator_params.update(HIGH_CONFIG)
+            else:
+                raise ValueError(f"Unknown model_type: {model_type}")
+
        # define modules
        generator_class = AVAILABLE_GENERATERS[generator_type]
        if generator_type == "vits_generator":
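
Because the presets are applied with `dict.update`, only the keys listed in the chosen
config override the defaults in `generator_params`; every other generator setting keeps
its default. A small sketch (not part of the commit) that prints what actually differs
between the three presets:

```python
from vits import HIGH_CONFIG, LOW_CONFIG, MEDIUM_CONFIG  # the dicts added above

for key in HIGH_CONFIG:
    low, medium, high = LOW_CONFIG[key], MEDIUM_CONFIG[key], HIGH_CONFIG[key]
    if not (low == medium == high):
        print(f"{key}: low={low}, medium={medium}, high={high}")
```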