A TTS recipe VITS on VCTK dataset (#1380)

* init

* isort formatted

* minor updates

* Create shared

* Update prepare_tokens_vctk.py

* Update prepare_tokens_vctk.py

* Update prepare_tokens_vctk.py

* Update prepare.sh

* updated

* Update train.py

* Update train.py

* Update tts_datamodule.py

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* Update train.py

* fixed formatting issue

* Update infer.py

* removed redundant files

* Create monotonic_align

* removed redundant files

* created symlinks

* Update prepare.sh

* minor adjustments

* Create requirements_tts.txt

* Update requirements_tts.txt

added version constraints

* Update infer.py

* Update infer.py

* Update infer.py

* updated docs

* Update export-onnx.py

* Update export-onnx.py

* Update test_onnx.py

* updated requirements.txt

* Update test_onnx.py

* Update test_onnx.py

* docs updated

* docs fixed

* minor updates
zr_jin 2023-12-06 09:59:19 +08:00 committed by GitHub
parent f08af2fa22
commit 735fb9a73d
48 changed files with 2904 additions and 84 deletions

View File

@ -5,3 +5,4 @@ TTS
:maxdepth: 2
ljspeech/vits
vctk/vits

View File

@ -4,6 +4,10 @@ VITS
This tutorial shows you how to train a VITS model
with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.
.. note::
TTS-related recipes require the packages listed in ``requirements-tts.txt``.
.. note::
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
@ -27,6 +31,12 @@ To run stage 1 to stage 5, use
Build Monotonic Alignment Search
--------------------------------
.. code-block:: bash
$ ./prepare.sh --stage -1 --stop_stage -1
or
.. code-block:: bash
$ cd vits/monotonic_align
@ -74,7 +84,7 @@ training part first. It will save the ground-truth and generated wavs to the dir
$ ./vits/infer.py \
--epoch 1000 \
--exp-dir vits/exp \
--tokens data/tokens.txt
--tokens data/tokens.txt \
--max-duration 500
.. note::

View File

@ -0,0 +1,125 @@
VITS
===============
This tutorial shows you how to train a VITS model
with the `VCTK <https://datashare.ed.ac.uk/handle/10283/3443>`_ dataset.
.. note::
TTS-related recipes require the packages listed in ``requirements-tts.txt``.
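For example, assuming you are in the root directory of the icefall repository, they can be installed with:
.. code-block:: bash
$ pip install -r requirements-tts.txt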
.. note::
The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
Data preparation
----------------
.. code-block:: bash
$ cd egs/vctk/TTS
$ ./prepare.sh
To run stage 1 to stage 6, use
.. code-block:: bash
$ ./prepare.sh --stage 1 --stop_stage 6
Build Monotonic Alignment Search
--------------------------------
To build the monotonic alignment search, use the following commands:
.. code-block:: bash
$ ./prepare.sh --stage -1 --stop_stage -1
or
.. code-block:: bash
$ cd vits/monotonic_align
$ python setup.py build_ext --inplace
$ cd ../../
Training
--------
.. code-block:: bash
$ export CUDA_VISIBLE_DEVICES="0,1,2,3"
$ ./vits/train.py \
--world-size 4 \
--num-epochs 1000 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
--tokens data/tokens.txt
--max-duration 350
.. note::
You can adjust the hyper-parameters to control the size of the VITS model and
the training configurations. For more details, please run ``./vits/train.py --help``.
.. note::
The training can take a long time (usually a couple of days).
Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``.
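You can monitor the training progress with tensorboard. The command below is a sketch; it assumes the event files are written to the ``tensorboard`` sub-directory of ``vits/exp``, as is the convention in icefall recipes:
.. code-block:: bash
$ tensorboard --logdir vits/exp/tensorboard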
Inference
---------
The inference part uses checkpoints saved by the training part, so you have to run the
training part first. It will save the ground-truth and generated wavs to the directory
``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``.
.. code-block:: bash
$ export CUDA_VISIBLE_DEVICES="0"
$ ./vits/infer.py \
--epoch 1000 \
--exp-dir vits/exp \
--tokens data/tokens.txt \
--max-duration 500
.. note::
For more details, please run ``./vits/infer.py --help``.
Export models
-------------
Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
.. code-block:: bash
$ ./vits/export-onnx.py \
--epoch 1000 \
--exp-dir vits/exp \
--tokens data/tokens.txt
You can test the exported ONNX model with:
.. code-block:: bash
$ ./vits/test_onnx.py \
--model-filename vits/exp/vits-epoch-1000.onnx \
--tokens data/tokens.txt
Download pretrained models
--------------------------
If you don't want to train from scratch, you can download the pretrained models
by visiting the following link:
- `<https://huggingface.co/zrjin/icefall-tts-vctk-vits-2023-12-05>`_
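For example, one way to fetch the checkpoint (a sketch, assuming ``git lfs`` is installed) is:
.. code-block:: bash
$ git lfs install
$ git clone https://huggingface.co/zrjin/icefall-tts-vctk-vits-2023-12-05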

View File

@ -5,8 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=1
stage=-1
stage=0
stop_stage=100
dl_dir=$PWD/download
@ -25,6 +24,17 @@ log() {
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: build monotonic_align lib"
if [ ! -d vits/monotonic_align/build ]; then
cd vits/monotonic_align
python setup.py build_ext --inplace
cd ../../
else
log "monotonic_align lib already built"
fi
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
@ -113,5 +123,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
--tokens data/tokens.txt
fi
fi

View File

@ -14,7 +14,6 @@ from typing import Optional
import torch
import torch.nn.functional as F
from flow import (
ConvFlow,
DilatedDepthSeparableConv,

View File

@ -180,7 +180,13 @@ def export_model_onnx(
model_filename,
verbose=False,
opset_version=opset_version,
input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"],
input_names=[
"tokens",
"tokens_lens",
"noise_scale",
"noise_scale_dur",
"alpha",
],
output_names=["audio"],
dynamic_axes={
"tokens": {0: "N", 1: "T"},

View File

@ -13,7 +13,6 @@ import math
from typing import Optional, Tuple, Union
import torch
from transform import piecewise_rational_quadratic_transform

View File

@ -16,9 +16,6 @@ from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from icefall.utils import make_pad_mask
from duration_predictor import StochasticDurationPredictor
from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder
@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock
from text_encoder import TextEncoder
from utils import get_random_segments
from icefall.utils import make_pad_mask
class VITSGenerator(torch.nn.Module):
"""Generator module in VITS, `Conditional Variational Autoencoder

View File

@ -36,13 +36,12 @@ import k2
import torch
import torch.nn as nn
import torchaudio
from train import get_model, get_params
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
from tts_datamodule import LJSpeechTtsDataModule
def get_parser():
@ -107,12 +106,12 @@ def infer_dataset(
for i in range(batch_size):
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
audio[i:i + 1, :audio_lens[i]],
audio[i : i + 1, : audio_lens[i]],
sample_rate=params.sampling_rate,
)
torchaudio.save(
str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
audio_pred[i:i + 1, :audio_lens_pred[i]],
audio_pred[i : i + 1, : audio_lens_pred[i]],
sample_rate=params.sampling_rate,
)
@ -144,14 +143,24 @@ def infer_dataset(
audio_lens = batch["audio_lens"].tolist()
cut_ids = [cut.id for cut in batch["cut"]]
audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
audio_pred, _, durations = model.inference_batch(
text=tokens, text_lengths=tokens_lens
)
audio_pred = audio_pred.detach().cpu()
# convert to samples
audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
audio_lens_pred = (
(durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
)
futures.append(
executor.submit(
_save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
_save_worker,
batch_size,
cut_ids,
audio,
audio_pred,
audio_lens,
audio_lens_pred,
)
)
@ -160,7 +169,9 @@ def infer_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
# return results
for f in futures:
f.result()

View File

@ -14,7 +14,6 @@ from typing import List, Tuple, Union
import torch
import torch.distributions as D
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank

View File

@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits.
from typing import Optional, Tuple
import torch
from wavenet import Conv1d, WaveNet
from icefall.utils import make_pad_mask
from wavenet import WaveNet, Conv1d
class PosteriorEncoder(torch.nn.Module):

View File

@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits.
from typing import Optional, Tuple, Union
import torch
from flow import FlipFlow
from wavenet import WaveNet

View File

@ -28,10 +28,10 @@ Use the onnx model to generate a wav:
import argparse
import logging
import onnxruntime as ort
import torch
import torchaudio
from tokenizer import Tokenizer

View File

@ -169,9 +169,7 @@ class Transformer(nn.Module):
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
x = self.encoder(
x, pos_emb, key_padding_mask=key_padding_mask
) # (T, N, C)
x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C)
x = self.after_norm(x)
@ -207,7 +205,9 @@ class TransformerEncoderLayer(nn.Module):
nn.Linear(dim_feedforward, d_model),
)
self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)
self.self_attn = RelPositionMultiheadAttention(
d_model, num_heads, dropout=dropout
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
@ -242,7 +242,9 @@ class TransformerEncoderLayer(nn.Module):
key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
"""
# macaron style feed-forward module
src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
src = src + self.ff_scale * self.dropout(
self.feed_forward_macaron(self.norm_ff_macaron(src))
)
# multi-head self-attention module
src_attn = self.self_attn(
@ -490,11 +492,17 @@ class RelPositionMultiheadAttention(nn.Module):
q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
v = (
v.contiguous()
.view(seq_len, batch_size * self.num_heads, self.head_dim)
.transpose(0, 1)
)
q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim)
p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim)
p = self.linear_pos(pos_emb).view(
pos_emb.size(0), -1, self.num_heads, self.head_dim
)
# (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
p = p.permute(0, 2, 3, 1)
@ -506,15 +514,23 @@ class RelPositionMultiheadAttention(nn.Module):
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len)
matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch_size, num_head, seq_len, seq_len)
# compute matrix b and matrix d
matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1)
matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len)
matrix_bd = torch.matmul(
q_with_bias_v, p
) # (batch_size, num_head, seq_len, 2*seq_len-1)
matrix_bd = self.rel_shift(
matrix_bd
) # (batch_size, num_head, seq_len, seq_len)
# (batch_size, num_head, seq_len, seq_len)
attn_output_weights = (matrix_ac + matrix_bd) * scaling
attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len)
attn_output_weights = attn_output_weights.view(
batch_size * self.num_heads, seq_len, seq_len
)
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, seq_len)
@ -536,10 +552,16 @@ class RelPositionMultiheadAttention(nn.Module):
# (batch_size * num_head, seq_len, head_dim)
attn_output = torch.bmm(attn_output_weights, v)
assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim)
assert attn_output.shape == (
batch_size * self.num_heads,
seq_len,
self.head_dim,
)
attn_output = (
attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim)
attn_output.transpose(0, 1)
.contiguous()
.view(seq_len, batch_size, self.embed_dim)
)
# (seq_len, batch_size, embed_dim)
attn_output = self.out_proj(attn_output)

View File

@ -78,7 +78,9 @@ class Tokenizer(object):
return token_ids_list
def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True):
def tokens_to_token_ids(
self, tokens_list: List[str], intersperse_blank: bool = True
):
"""
Args:
tokens_list:

View File

@ -18,21 +18,25 @@
import argparse
import logging
import numpy as np
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from torch.optim import Optimizer
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import LJSpeechTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
@ -41,11 +45,6 @@ from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, setup_logger, str2bool
from tokenizer import Tokenizer
from tts_datamodule import LJSpeechTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
@ -385,11 +384,12 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["tokens"])
audio, audio_lens, features, features_lens, tokens, tokens_lens = \
prepare_input(batch, tokenizer, device)
audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
batch, tokenizer, device
)
loss_info = MetricsTracker()
loss_info['samples'] = batch_size
loss_info["samples"] = batch_size
try:
with autocast(enabled=params.use_fp16):
@ -446,7 +446,9 @@ def train_one_epoch(
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
if cur_grad_scale < 8.0 or (
cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
@ -482,9 +484,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
@ -492,19 +492,34 @@ def train_one_epoch(
if "returned_sample" in stats_g:
speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
tb_writer.add_audio(
"train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
"train/speech_hat_",
speech_hat_,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"train/speech_", speech_, params.batch_idx_train, params.sampling_rate
"train/speech_",
speech_,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_image(
"train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
"train/mel_hat_",
plot_feature(mel_hat_),
params.batch_idx_train,
dataformats="HWC",
)
tb_writer.add_image(
"train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
"train/mel_",
plot_feature(mel_),
params.batch_idx_train,
dataformats="HWC",
)
if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
if (
params.batch_idx_train % params.valid_interval == 0
and not params.print_diagnostics
):
logging.info("Computing validation loss")
valid_info, (speech_hat, speech) = compute_validation_loss(
params=params,
@ -523,10 +538,16 @@ def train_one_epoch(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_audio(
"train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
"train/valdi_speech_hat",
speech_hat,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
"train/valdi_speech",
speech,
params.batch_idx_train,
params.sampling_rate,
)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
@ -555,11 +576,17 @@ def compute_validation_loss(
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
batch_size = len(batch["tokens"])
audio, audio_lens, features, features_lens, tokens, tokens_lens = \
prepare_input(batch, tokenizer, device)
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
) = prepare_input(batch, tokenizer, device)
loss_info = MetricsTracker()
loss_info['samples'] = batch_size
loss_info["samples"] = batch_size
# forward discriminator
loss_d, stats_d = model(
@ -596,12 +623,17 @@ def compute_validation_loss(
if batch_idx == 0 and rank == 0:
inner_model = model.module if isinstance(model, DDP) else model
audio_pred, _, duration = inner_model.inference(
text=tokens[0, :tokens_lens[0].item()]
text=tokens[0, : tokens_lens[0].item()]
)
audio_pred = audio_pred.data.cpu().numpy()
audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
audio_len_pred = (
(duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
)
assert audio_len_pred == len(audio_pred), (
audio_len_pred,
len(audio_pred),
)
audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
returned_sample = (audio_pred, audio_gt)
if world_size > 1:
@ -632,8 +664,9 @@ def scan_pessimistic_batches_for_oom(
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
audio, audio_lens, features, features_lens, tokens, tokens_lens = \
prepare_input(batch, tokenizer, device)
audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
batch, tokenizer, device
)
try:
# for discriminator
with autocast(enabled=params.use_fp16):

View File

@ -29,10 +29,10 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
SpeechSynthesisDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,

View File

@ -14,15 +14,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import collections
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from pathlib import Path
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@ -97,23 +97,23 @@ def plot_feature(spectrogram):
global MATPLOTLIB_FLAG
if not MATPLOTLIB_FLAG:
import matplotlib
matplotlib.use("Agg")
MATPLOTLIB_FLAG = True
mpl_logger = logging.getLogger('matplotlib')
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
interpolation='none')
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data

View File

@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from generator import VITSGenerator
from hifigan import (
HiFiGANMultiPeriodDiscriminator,
HiFiGANMultiScaleDiscriminator,
@ -25,9 +24,8 @@ from loss import (
KLDivergenceLoss,
MelSpectrogramLoss,
)
from torch.cuda.amp import autocast
from utils import get_segments
from generator import VITSGenerator
AVAILABLE_GENERATERS = {
"vits_generator": VITSGenerator,
@ -42,8 +40,7 @@ AVAILABLE_DISCRIMINATORS = {
class VITS(nn.Module):
"""Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`
"""
"""Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""
def __init__(
self,

View File

@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
"""
import math
import logging
import math
from typing import Optional, Tuple
import torch

View File

@ -0,0 +1,107 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes spectrogram features of the VCTK dataset.
It looks for manifests in the directory data/manifests.
The generated spectrogram features are saved in data/spectrogram.
"""
import logging
import os
from pathlib import Path
import torch
from lhotse import (
CutSet,
LilcomChunkyWriter,
Spectrogram,
SpectrogramConfig,
load_manifest,
)
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_spectrogram_vctk():
src_dir = Path("data/manifests")
output_dir = Path("data/spectrogram")
num_jobs = min(32, os.cpu_count())
sampling_rate = 22050
frame_length = 1024 / sampling_rate # (in second)
frame_shift = 256 / sampling_rate # (in second)
use_fft_mag = True
prefix = "vctk"
suffix = "jsonl.gz"
partition = "all"
recordings = load_manifest(
src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
).resample(sampling_rate=sampling_rate)
supervisions = load_manifest(
src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
)
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=frame_length,
frame_shift=frame_shift,
use_fft_mag=use_fft_mag,
)
extractor = Spectrogram(config)
with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
return
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=recordings, supervisions=supervisions
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_spectrogram_vctk()

View File

@ -0,0 +1,83 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.
See the function `remove_short_and_long_utt()` in vits/train.py
for usage.
"""
from lhotse import load_manifest_lazy
def main():
path = "./data/spectrogram/vctk_cuts_all.jsonl.gz"
cuts = load_manifest_lazy(path)
cuts.describe()
if __name__ == "__main__":
main()
"""
Cut statistics:
Cuts count: 43873
Total duration (hh:mm:ss) 41:02:18
mean 3.4
std 1.2
min 1.2
25% 2.6
50% 3.1
75% 3.8
99% 8.0
99.5% 9.1
99.9% 12.1
max 16.6
Recordings available: 43873
Features available: 43873
Supervisions available: 43873
SUPERVISION custom fields:
Speech duration statistics:
Total speech duration 41:02:18 100.00% of recording
Total speaking time duration 41:02:18 100.00% of recording
Total silence duration 00:00:01 0.00% of recording
"""

View File

@ -0,0 +1,104 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file reads the texts in the given manifest and generates the file that maps tokens to IDs.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict
from lhotse import load_manifest
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifest-file",
type=Path,
default=Path("data/spectrogram/vctk_cuts_all.jsonl.gz"),
help="Path to the manifest file",
)
parser.add_argument(
"--tokens",
type=Path,
default=Path("data/tokens.txt"),
help="Path to the tokens",
)
return parser.parse_args()
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_token2id(manifest_file: Path) -> Dict[str, int]:
"""Return a dict that maps token to IDs."""
extra_tokens = [
"<blk>", # 0 for blank
"<sos/eos>", # 1 for sos and eos symbols.
"<unk>", # 2 for OOV
]
all_tokens = set()
cut_set = load_manifest(manifest_file)
for cut in cut_set:
# Each cut only contains one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
for t in cut.tokens:
all_tokens.add(t)
all_tokens = extra_tokens + list(all_tokens)
token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
return token2id
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
manifest_file = Path(args.manifest_file)
out_file = Path(args.tokens)
token2id = get_token2id(manifest_file)
write_mapping(out_file, token2id)

View File

@ -0,0 +1,61 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file reads the texts in the given manifest and saves the new cuts with phoneme tokens.
"""
import logging
from pathlib import Path
import g2p_en
import tacotron_cleaner.cleaners
from lhotse import CutSet, load_manifest
from tqdm.auto import tqdm
def prepare_tokens_vctk():
output_dir = Path("data/spectrogram")
prefix = "vctk"
suffix = "jsonl.gz"
partition = "all"
cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
g2p = g2p_en.G2p()
new_cuts = []
for cut in tqdm(cut_set):
# Each cut only contains one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
text = cut.supervisions[0].text
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
cut.tokens = g2p(text)
new_cuts.append(cut)
new_cut_set = CutSet.from_cuts(new_cuts)
new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
prepare_tokens_vctk()

View File

@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
We will add more checks later if needed.
Usage example:
python3 ./local/validate_manifest.py \
./data/spectrogram/vctk_cuts_all.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset.speech_synthesis import validate_for_tts
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)
return parser.parse_args()
def main():
args = get_args()
manifest = args.manifest
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest_lazy(manifest)
assert isinstance(cut_set, CutSet)
validate_for_tts(cut_set)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

egs/vctk/TTS/prepare.sh Executable file
View File

@ -0,0 +1,131 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=0
stop_stage=100
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: build monotonic_align lib"
if [ ! -d vits/monotonic_align/build ]; then
cd vits/monotonic_align
python setup.py build_ext --inplace
cd ../../
else
log "monotonic_align lib already built"
fi
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/VCTK,
# you can create a symlink
#
# ln -sfv /path/to/VCTK $dl_dir/VCTK
#
if [ ! -d $dl_dir/VCTK ]; then
lhotse download vctk $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare VCTK manifest"
# We assume that you have downloaded the VCTK corpus
# to $dl_dir/VCTK
mkdir -p data/manifests
if [ ! -e data/manifests/.vctk.done ]; then
lhotse prepare vctk --use-edinburgh-vctk-url true $dl_dir/VCTK data/manifests
touch data/manifests/.vctk.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Compute spectrogram for VCTK"
mkdir -p data/spectrogram
if [ ! -e data/spectrogram/.vctk.done ]; then
./local/compute_spectrogram_vctk.py
touch data/spectrogram/.vctk.done
fi
if [ ! -e data/spectrogram/.vctk-validated.done ]; then
log "Validating data/fbank for VCTK"
./local/validate_manifest.py \
data/spectrogram/vctk_cuts_all.jsonl.gz
touch data/spectrogram/.vctk-validated.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare phoneme tokens for VCTK"
if [ ! -e data/spectrogram/.vctk_with_token.done ]; then
./local/prepare_tokens_vctk.py
mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \
data/spectrogram/vctk_cuts_all.jsonl.gz
touch data/spectrogram/.vctk_with_token.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Split the VCTK cuts into train, valid and test sets"
if [ ! -e data/spectrogram/.vctk_split.done ]; then
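# Hold out the last 600 cuts; the first 100 of them become the valid set and the
# remaining 500 the test set, while all the other cuts are used for training.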
lhotse subset --last 600 \
data/spectrogram/vctk_cuts_all.jsonl.gz \
data/spectrogram/vctk_cuts_validtest.jsonl.gz
lhotse subset --first 100 \
data/spectrogram/vctk_cuts_validtest.jsonl.gz \
data/spectrogram/vctk_cuts_valid.jsonl.gz
lhotse subset --last 500 \
data/spectrogram/vctk_cuts_validtest.jsonl.gz \
data/spectrogram/vctk_cuts_test.jsonl.gz
rm data/spectrogram/vctk_cuts_validtest.jsonl.gz
n=$(( $(gunzip -c data/spectrogram/vctk_cuts_all.jsonl.gz | wc -l) - 600 ))
lhotse subset --first $n \
data/spectrogram/vctk_cuts_all.jsonl.gz \
data/spectrogram/vctk_cuts_train.jsonl.gz
touch data/spectrogram/.vctk_split.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate token file"
# We assume you have installed g2p_en and espnet_tts_frontend.
# If not, please install them with:
# - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
# - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
if [ ! -e data/tokens.txt ]; then
./local/prepare_token_file.py \
--manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \
--tokens data/tokens.txt
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Generate speakers file"
if [ ! -e data/speakers.txt ]; then
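# Collect the sorted list of unique speaker IDs from the supervisions, one per line.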
gunzip -c data/manifests/vctk_supervisions_all.jsonl.gz \
| jq '.speaker' | sed 's/"//g' \
| sort | uniq > data/speakers.txt
fi
fi

egs/vctk/TTS/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/duration_predictor.py

egs/vctk/TTS/vits/export-onnx.py Executable file
View File

@ -0,0 +1,284 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script exports a VITS model from PyTorch to ONNX.
Export the model to ONNX:
./vits/export-onnx.py \
--epoch 1000 \
--exp-dir vits/exp \
--tokens data/tokens.txt
It will generate two files inside vits/exp:
- vits-epoch-1000.onnx
- vits-epoch-1000.int8.onnx (quantized model)
See ./test_onnx.py for how to use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params
from icefall.checkpoint import load_checkpoint
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=1000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
help="The experiment dir",
)
parser.add_argument(
"--speakers",
type=Path,
default=Path("data/speakers.txt"),
help="Path to speakers.txt file.",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxModel(nn.Module):
"""A wrapper for VITS generator."""
def __init__(self, model: nn.Module):
"""
Args:
model:
A VITS generator.
frame_shift:
The frame shift in samples.
"""
super().__init__()
self.model = model
def forward(
self,
tokens: torch.Tensor,
tokens_lens: torch.Tensor,
noise_scale: float = 0.667,
noise_scale_dur: float = 0.8,
speaker: int = 20,
alpha: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of VITS.inference_batch
Args:
tokens:
Input text token indexes (1, T_text)
tokens_lens:
Number of tokens of shape (1,)
noise_scale (float):
Noise scale parameter for flow.
noise_scale_dur (float):
Noise scale parameter for duration predictor.
speaker (int):
Speaker ID.
alpha (float):
Alpha parameter to control the speed of generated speech.
Returns:
Return a tuple containing:
- audio, generated waveform tensor, (B, T_wav)
"""
audio, _, _ = self.model.inference(
text=tokens,
text_lengths=tokens_lens,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
sids=speaker,
alpha=alpha,
)
return audio
def export_model_onnx(
model: nn.Module,
model_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
The exported model has one input:
- tokens, a tensor of shape (1, T_text); dtype is torch.int64
and it has one output:
- audio, a tensor of shape (1, T'); dtype is torch.float32
Args:
model:
The VITS generator.
model_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
alpha = torch.tensor([1], dtype=torch.float32)
speaker = torch.tensor([1], dtype=torch.int64)
torch.onnx.export(
model,
(tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha),
model_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"tokens",
"tokens_lens",
"noise_scale",
"noise_scale_dur",
"speaker",
"alpha",
],
output_names=["audio"],
dynamic_axes={
"tokens": {0: "N", 1: "T"},
"tokens_lens": {0: "N"},
"audio": {0: "N", 1: "T"},
"speaker": {0: "N"},
},
)
meta_data = {
"model_type": "VITS",
"version": "1",
"model_author": "k2-fsa",
"comment": "VITS generator",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=model_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.vocab_size = tokenizer.vocab_size
with open(args.speakers) as f:
speaker_map = {line.strip(): i for i, line in enumerate(f)}
params.num_spks = len(speaker_map)
logging.info(params)
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model = model.generator
model.to("cpu")
model.eval()
model = OnnxModel(model=model)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"generator parameters: {num_param}")
suffix = f"epoch-{params.epoch}"
opset_version = 13
logging.info("Exporting encoder")
model_filename = params.exp_dir / f"vits-{suffix}.onnx"
export_model_onnx(
model,
model_filename,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
quantize_dynamic(
model_input=model_filename,
model_output=model_filename_int8,
weight_type=QuantType.QUInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

egs/vctk/TTS/vits/flow.py Symbolic link
View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/flow.py

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/generator.py

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/hifigan.py

egs/vctk/TTS/vits/infer.py Executable file
View File

@ -0,0 +1,272 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script performs model inference on the test set.
Usage:
./vits/infer.py \
--epoch 1000 \
--exp-dir ./vits/exp \
--max-duration 500
"""
import argparse
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List
import k2
import torch
import torch.nn as nn
import torchaudio
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import VctkTtsDataModule
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=1000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
help="The experiment dir",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
return parser
def infer_dataset(
dl: torch.utils.data.DataLoader,
subset: str,
params: AttributeDict,
model: nn.Module,
tokenizer: Tokenizer,
speaker_map: Dict[str, int],
) -> None:
"""Decode dataset.
The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
tokenizer:
Used to convert text to phonemes.
"""
# Background worker save audios to disk.
def _save_worker(
subset: str,
batch_size: int,
cut_ids: List[str],
audio: torch.Tensor,
audio_pred: torch.Tensor,
audio_lens: List[int],
audio_lens_pred: List[int],
):
for i in range(batch_size):
torchaudio.save(
str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
audio[i : i + 1, : audio_lens[i]],
sample_rate=params.sampling_rate,
)
torchaudio.save(
str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"),
audio_pred[i : i + 1, : audio_lens_pred[i]],
sample_rate=params.sampling_rate,
)
device = next(model.parameters()).device
num_cuts = 0
log_interval = 5
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
futures = []
with ThreadPoolExecutor(max_workers=1) as executor:
for batch_idx, batch in enumerate(dl):
batch_size = len(batch["tokens"])
tokens = batch["tokens"]
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
speakers = (
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
.int()
.to(device)
)
audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist()
cut_ids = [cut.id for cut in batch["cut"]]
audio_pred, _, durations = model.inference_batch(
text=tokens,
text_lengths=tokens_lens,
sids=speakers,
)
audio_pred = audio_pred.detach().cpu()
# convert to samples
audio_lens_pred = (
(durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
)
futures.append(
executor.submit(
_save_worker,
subset,
batch_size,
cut_ids,
audio,
audio_pred,
audio_lens,
audio_lens_pred,
)
)
num_cuts += batch_size
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
# return results
for f in futures:
f.result()
@torch.no_grad()
def main():
parser = get_parser()
VctkTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}"
params.res_dir = params.exp_dir / "infer" / params.suffix
params.save_wav_dir = params.res_dir / "wav"
params.save_wav_dir.mkdir(parents=True, exist_ok=True)
setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
logging.info("Infer started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.blank_id
params.oov_id = tokenizer.oov_id
params.vocab_size = tokenizer.vocab_size
# we need cut ids to display recognition results.
args.return_cuts = True
vctk = VctkTtsDataModule(args)
speaker_map = vctk.speakers()
params.num_spks = len(speaker_map)
logging.info(f"Device: {device}")
logging.info(params)
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
num_param_g = sum([p.numel() for p in model.generator.parameters()])
logging.info(f"Number of parameters in generator: {num_param_g}")
num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
logging.info(f"Number of parameters in discriminator: {num_param_d}")
logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
test_cuts = vctk.test_cuts()
test_dl = vctk.test_dataloaders(test_cuts)
valid_cuts = vctk.valid_cuts()
valid_dl = vctk.valid_dataloaders(valid_cuts)
infer_sets = {"test": test_dl, "valid": valid_dl}
for subset, dl in infer_sets.items():
save_wav_dir = params.res_dir / "wav" / subset
save_wav_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Processing {subset} set, saving to {save_wav_dir}")
infer_dataset(
dl=dl,
subset=subset,
params=params,
model=model,
tokenizer=tokenizer,
speaker_map=speaker_map,
)
logging.info(f"Wav files are saved to {params.save_wav_dir}")
logging.info("Done!")
if __name__ == "__main__":
main()

egs/vctk/TTS/vits/loss.py Symbolic link
View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/loss.py

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/monotonic_align

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/posterior_encoder.py

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/residual_coupling.py

egs/vctk/TTS/vits/test_onnx.py Executable file
View File

@ -0,0 +1,138 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to test the exported onnx model by vits/export-onnx.py
Use the onnx model to generate a wav:
./vits/test_onnx.py \
--model-filename vits/exp/vits-epoch-1000.onnx \
--tokens data/tokens.txt
"""
import argparse
import logging
from pathlib import Path
import onnxruntime as ort
import torch
import torchaudio
from tokenizer import Tokenizer
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model-filename",
type=str,
required=True,
help="Path to the onnx model.",
)
parser.add_argument(
"--speakers",
type=Path,
default=Path("data/speakers.txt"),
help="Path to speakers.txt file.",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
return parser
class OnnxModel:
def __init__(self, model_filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 4
self.session_opts = session_opts
self.model = ort.InferenceSession(
model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
def __call__(
self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor
) -> torch.Tensor:
"""
Args:
tokens:
A 1-D tensor of shape (1, T)
Returns:
A tensor of shape (1, T')
"""
noise_scale = torch.tensor([0.667], dtype=torch.float32)
noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
alpha = torch.tensor([1.0], dtype=torch.float32)
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
self.model.get_inputs()[4].name: speaker.numpy(),
self.model.get_inputs()[5].name: alpha.numpy(),
},
)[0]
return torch.from_numpy(out)
def main():
args = get_parser().parse_args()
tokenizer = Tokenizer(args.tokens)
with open(args.speakers) as f:
speaker_map = {line.strip(): i for i, line in enumerate(f)}
args.num_spks = len(speaker_map)
logging.info("About to create onnx model")
model = OnnxModel(args.model_filename)
text = "I went there to see the land, the people and how their system works, end quote."
tokens = tokenizer.texts_to_token_ids([text])
tokens = torch.tensor(tokens) # (1, T)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T)
speaker = torch.tensor([1], dtype=torch.int64) # (1, )
audio = model(tokens, tokens_lens, speaker) # (1, T')
torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050)
logging.info("Saved to test_onnx.wav")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/text_encoder.py

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/tokenizer.py

egs/vctk/TTS/vits/train.py Executable file

File diff suppressed because it is too large

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/transform.py

View File

@ -0,0 +1,338 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class VctkTtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/spectrogram"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--speakers",
type=Path,
default=Path("data/speakers.txt"),
help="Path to speakers.txt file.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=8,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=1024 / sampling_rate, # (in second),
frame_shift=256 / sampling_rate, # (in second)
use_fft_mag=True,
)
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=1024 / sampling_rate, # (in second),
frame_shift=256 / sampling_rate, # (in second)
use_fft_mag=True,
)
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=1024 / sampling_rate, # (in second),
frame_shift=256 / sampling_rate, # (in second)
use_fft_mag=True,
)
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
else:
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
return_spk_ids=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")
@lru_cache()
def speakers(self) -> Dict[str, int]:
logging.info("About to get speakers")
with open(self.args.speakers) as f:
speakers = {line.strip(): i for i, line in enumerate(f)}
return speakers

egs/vctk/TTS/vits/utils.py Symbolic link
View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/utils.py

egs/vctk/TTS/vits/vits.py Symbolic link
View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/vits.py

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/vits/wavenet.py

requirements-tts.txt Normal file
View File

@ -0,0 +1,6 @@
# for TTS recipes
matplotlib==3.8.2
cython==3.0.6
numba==0.58.1
g2p_en==2.1.0
espnet_tts_frontend==0.0.3

View File

@ -8,3 +8,5 @@ tensorboard
typeguard
dill
black==22.3.0
onnx==1.15.0
onnxruntime==1.16.3