Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 09:32:20 +00:00
A TTS recipe VITS on VCTK dataset (#1380)
* init
* isort formatted
* minor updates
* Create shared
* Update prepare_tokens_vctk.py
* Update prepare_tokens_vctk.py
* Update prepare_tokens_vctk.py
* Update prepare.sh
* updated
* Update train.py
* Update train.py
* Update tts_datamodule.py
* Update train.py
* Update train.py
* Update train.py
* Update train.py
* Update train.py
* Update train.py
* fixed formatting issue
* Update infer.py
* removed redundant files
* Create monotonic_align
* removed redundant files
* created symlinks
* Update prepare.sh
* minor adjustments
* Create requirements_tts.txt
* Update requirements_tts.txt added version constraints
* Update infer.py
* Update infer.py
* Update infer.py
* updated docs
* Update export-onnx.py
* Update export-onnx.py
* Update test_onnx.py
* updated requirements.txt
* Update test_onnx.py
* Update test_onnx.py
* docs updated
* docs fixed
* minor updates
This commit is contained in:
parent: f08af2fa22
commit: 735fb9a73d
@@ -5,3 +5,4 @@ TTS
   :maxdepth: 2

   ljspeech/vits
   vctk/vits
@@ -4,6 +4,10 @@ VITS
This tutorial shows you how to train a VITS model
with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.

.. note::

   TTS related recipes require packages in ``requirements-tts.txt``.

.. note::

   The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_
@@ -27,6 +31,12 @@ To run stage 1 to stage 5, use
Build Monotonic Alignment Search
--------------------------------

.. code-block:: bash

  $ ./prepare.sh --stage -1 --stop_stage -1

or

.. code-block:: bash

  $ cd vits/monotonic_align
@@ -74,7 +84,7 @@ training part first. It will save the ground-truth and generated wavs to the dir
  $ ./vits/infer.py \
    --epoch 1000 \
    --exp-dir vits/exp \
    --tokens data/tokens.txt
    --tokens data/tokens.txt \
    --max-duration 500

.. note::
docs/source/recipes/TTS/vctk/vits.rst (new file, 125 lines)
@@ -0,0 +1,125 @@
VITS
===============

This tutorial shows you how to train a VITS model
with the `VCTK <https://datashare.ed.ac.uk/handle/10283/3443>`_ dataset.

.. note::

   TTS related recipes require packages in ``requirements-tts.txt``.

.. note::

   The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech <https://arxiv.org/pdf/2106.06103.pdf>`_


Data preparation
----------------

.. code-block:: bash

  $ cd egs/vctk/TTS
  $ ./prepare.sh

To run stage 1 to stage 6, use

.. code-block:: bash

  $ ./prepare.sh --stage 1 --stop_stage 6


Build Monotonic Alignment Search
--------------------------------

To build the monotonic alignment search, use the following commands:

.. code-block:: bash

  $ ./prepare.sh --stage -1 --stop_stage -1

or

.. code-block:: bash

  $ cd vits/monotonic_align
  $ python setup.py build_ext --inplace
  $ cd ../../


Training
--------

.. code-block:: bash

  $ export CUDA_VISIBLE_DEVICES="0,1,2,3"
  $ ./vits/train.py \
    --world-size 4 \
    --num-epochs 1000 \
    --start-epoch 1 \
    --use-fp16 1 \
    --exp-dir vits/exp \
    --tokens data/tokens.txt \
    --max-duration 350

.. note::

   You can adjust the hyper-parameters to control the size of the VITS model and
   the training configurations. For more details, please run ``./vits/train.py --help``.

.. note::

   The training can take a long time (usually a couple of days).

Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``.


Inference
---------

The inference part uses checkpoints saved by the training part, so you have to run the
training part first. It will save the ground-truth and generated wavs to the directory
``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``.

.. code-block:: bash

  $ export CUDA_VISIBLE_DEVICES="0"
  $ ./vits/infer.py \
    --epoch 1000 \
    --exp-dir vits/exp \
    --tokens data/tokens.txt \
    --max-duration 500

.. note::

   For more details, please run ``./vits/infer.py --help``.


Export models
-------------

Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``:
``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.

.. code-block:: bash

  $ ./vits/export-onnx.py \
    --epoch 1000 \
    --exp-dir vits/exp \
    --tokens data/tokens.txt

You can test the exported ONNX model with:

.. code-block:: bash

  $ ./vits/test_onnx.py \
    --model-filename vits/exp/vits-epoch-1000.onnx \
    --tokens data/tokens.txt


Download pretrained models
--------------------------

If you don't want to train from scratch, you can download the pretrained models
by visiting the following link:

  - `<https://huggingface.co/zrjin/icefall-tts-vctk-vits-2023-12-05>`_
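(Editorial aside, not part of this commit.) Since VCTK is multi-speaker, the exported ONNX model takes a speaker ID in addition to the token IDs. A minimal sketch of driving it with ``onnxruntime``; the input and output names (``tokens``, ``tokens_lens``, ``noise_scale``, ``noise_scale_dur``, ``speaker``, ``alpha``, ``audio``) follow ``vits/export-onnx.py`` later in this commit, while the token IDs below are placeholders that in reality come from ``data/tokens.txt``:

.. code-block:: python

  import numpy as np
  import onnxruntime as ort

  session = ort.InferenceSession("vits/exp/vits-epoch-1000.onnx")

  token_ids = [10, 25, 7, 42]  # hypothetical IDs looked up in data/tokens.txt
  (audio,) = session.run(
      ["audio"],
      {
          "tokens": np.array([token_ids], dtype=np.int64),
          "tokens_lens": np.array([len(token_ids)], dtype=np.int64),
          "noise_scale": np.array([0.667], dtype=np.float32),
          "noise_scale_dur": np.array([0.8], dtype=np.float32),
          "speaker": np.array([20], dtype=np.int64),  # line index in data/speakers.txt
          "alpha": np.array([1.0], dtype=np.float32),  # controls speech speed
      },
  )
  # audio: float32 array of shape (1, T_wav) at the training sampling rate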
@@ -5,8 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

nj=1
stage=-1
stage=0
stop_stage=100

dl_dir=$PWD/download
@@ -25,6 +24,17 @@ log() {

log "dl_dir: $dl_dir"

if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
  log "Stage -1: build monotonic_align lib"
  if [ ! -d vits/monotonic_align/build ]; then
    cd vits/monotonic_align
    python setup.py build_ext --inplace
    cd ../../
  else
    log "monotonic_align lib already built"
  fi
fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"
@@ -113,5 +123,3 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
      --tokens data/tokens.txt
  fi
fi
@@ -14,7 +14,6 @@ from typing import Optional

import torch
import torch.nn.functional as F

from flow import (
    ConvFlow,
    DilatedDepthSeparableConv,
@@ -180,7 +180,13 @@ def export_model_onnx(
        model_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"],
        input_names=[
            "tokens",
            "tokens_lens",
            "noise_scale",
            "noise_scale_dur",
            "alpha",
        ],
        output_names=["audio"],
        dynamic_axes={
            "tokens": {0: "N", 1: "T"},
@@ -13,7 +13,6 @@ import math
from typing import Optional, Tuple, Union

import torch

from transform import piecewise_rational_quadratic_transform
@@ -16,9 +16,6 @@ from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F

from icefall.utils import make_pad_mask

from duration_predictor import StochasticDurationPredictor
from hifigan import HiFiGANGenerator
from posterior_encoder import PosteriorEncoder
@@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock
from text_encoder import TextEncoder
from utils import get_random_segments

from icefall.utils import make_pad_mask


class VITSGenerator(torch.nn.Module):
    """Generator module in VITS, `Conditional Variational Autoencoder
@@ -36,13 +36,12 @@ import k2
import torch
import torch.nn as nn
import torchaudio

from train import get_model, get_params
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import LJSpeechTtsDataModule

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
from tts_datamodule import LJSpeechTtsDataModule


def get_parser():
@@ -144,14 +143,24 @@ def infer_dataset(
            audio_lens = batch["audio_lens"].tolist()
            cut_ids = [cut.id for cut in batch["cut"]]

            audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
            audio_pred, _, durations = model.inference_batch(
                text=tokens, text_lengths=tokens_lens
            )
            audio_pred = audio_pred.detach().cpu()
            # convert to samples
            audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
            audio_lens_pred = (
                (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
            )

            futures.append(
                executor.submit(
                    _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
                    _save_worker,
                    batch_size,
                    cut_ids,
                    audio,
                    audio_pred,
                    audio_lens,
                    audio_lens_pred,
                )
            )
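(A side note on the unit conversion above, not part of the commit: `durations` holds per-token frame counts and `params.frame_shift` is, presumably, the hop size in samples, 256 in this recipe, so summing over tokens and multiplying yields the predicted waveform length in samples. A minimal sketch under that assumption:

    import torch

    durations = torch.tensor([[3.0, 2.0, 4.0]])  # hypothetical per-token frame counts
    frame_shift = 256  # samples per frame, matching the 256-sample hop used here
    audio_lens_pred = (durations.sum(1) * frame_shift).to(dtype=torch.int64).tolist()
    # (3 + 2 + 4) frames * 256 samples/frame = 2304 samples
    assert audio_lens_pred == [2304]
)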
@@ -160,7 +169,9 @@ def infer_dataset(
            if batch_idx % log_interval == 0:
                batch_str = f"{batch_idx}/{num_batches}"

                logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
                logging.info(
                    f"batch {batch_str}, cuts processed until now is {num_cuts}"
                )
        # return results
        for f in futures:
            f.result()
@@ -14,7 +14,6 @@ from typing import List, Tuple, Union
import torch
import torch.distributions as D
import torch.nn.functional as F

from lhotse.features.kaldi import Wav2LogFilterBank
@@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits.
from typing import Optional, Tuple

import torch
from wavenet import Conv1d, WaveNet

from icefall.utils import make_pad_mask
from wavenet import WaveNet, Conv1d


class PosteriorEncoder(torch.nn.Module):
@@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits.
from typing import Optional, Tuple, Union

import torch

from flow import FlipFlow
from wavenet import WaveNet
@@ -28,10 +28,10 @@ Use the onnx model to generate a wav:

import argparse
import logging

import onnxruntime as ort
import torch
import torchaudio

from tokenizer import Tokenizer
@@ -169,9 +169,7 @@ class Transformer(nn.Module):
        x, pos_emb = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        x = self.encoder(
            x, pos_emb, key_padding_mask=key_padding_mask
        )  # (T, N, C)
        x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask)  # (T, N, C)

        x = self.after_norm(x)
@@ -207,7 +205,9 @@ class TransformerEncoderLayer(nn.Module):
            nn.Linear(dim_feedforward, d_model),
        )

        self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)
        self.self_attn = RelPositionMultiheadAttention(
            d_model, num_heads, dropout=dropout
        )

        self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
@@ -242,7 +242,9 @@ class TransformerEncoderLayer(nn.Module):
            key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
        """
        # macaron style feed-forward module
        src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
        src = src + self.ff_scale * self.dropout(
            self.feed_forward_macaron(self.norm_ff_macaron(src))
        )

        # multi-head self-attention module
        src_attn = self.self_attn(
@@ -490,11 +492,17 @@ class RelPositionMultiheadAttention(nn.Module):

        q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
        k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
        v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
        v = (
            v.contiguous()
            .view(seq_len, batch_size * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

        q = q.transpose(0, 1)  # (batch_size, seq_len, num_head, head_dim)

        p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim)
        p = self.linear_pos(pos_emb).view(
            pos_emb.size(0), -1, self.num_heads, self.head_dim
        )
        # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
        p = p.permute(0, 2, 3, 1)
@@ -506,15 +514,23 @@ class RelPositionMultiheadAttention(nn.Module):
        # first compute matrix a and matrix c
        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
        k = k.permute(1, 2, 3, 0)  # (batch_size, num_head, head_dim, seq_len)
        matrix_ac = torch.matmul(q_with_bias_u, k)  # (batch_size, num_head, seq_len, seq_len)
        matrix_ac = torch.matmul(
            q_with_bias_u, k
        )  # (batch_size, num_head, seq_len, seq_len)

        # compute matrix b and matrix d
        matrix_bd = torch.matmul(q_with_bias_v, p)  # (batch_size, num_head, seq_len, 2*seq_len-1)
        matrix_bd = self.rel_shift(matrix_bd)  # (batch_size, num_head, seq_len, seq_len)
        matrix_bd = torch.matmul(
            q_with_bias_v, p
        )  # (batch_size, num_head, seq_len, 2*seq_len-1)
        matrix_bd = self.rel_shift(
            matrix_bd
        )  # (batch_size, num_head, seq_len, seq_len)

        # (batch_size, num_head, seq_len, seq_len)
        attn_output_weights = (matrix_ac + matrix_bd) * scaling
        attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len)
        attn_output_weights = attn_output_weights.view(
            batch_size * self.num_heads, seq_len, seq_len
        )

        if key_padding_mask is not None:
            assert key_padding_mask.shape == (batch_size, seq_len)
@@ -536,10 +552,16 @@ class RelPositionMultiheadAttention(nn.Module):

        # (batch_size * num_head, seq_len, head_dim)
        attn_output = torch.bmm(attn_output_weights, v)
        assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim)
        assert attn_output.shape == (
            batch_size * self.num_heads,
            seq_len,
            self.head_dim,
        )

        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim)
            attn_output.transpose(0, 1)
            .contiguous()
            .view(seq_len, batch_size, self.embed_dim)
        )
        # (seq_len, batch_size, embed_dim)
        attn_output = self.out_proj(attn_output)
@@ -78,7 +78,9 @@ class Tokenizer(object):

        return token_ids_list

    def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True):
    def tokens_to_token_ids(
        self, tokens_list: List[str], intersperse_blank: bool = True
    ):
        """
        Args:
          tokens_list:
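(An aside, not part of the commit: `intersperse_blank=True` presumably interleaves the blank ID, `<blk>` = 0 in tokens.txt, between consecutive token IDs, the usual VITS preprocessing trick. A hedged sketch of that behaviour:

    def intersperse(token_ids, blank_id=0):
        # [5, 9, 7] -> [0, 5, 0, 9, 0, 7, 0]
        result = [blank_id] * (2 * len(token_ids) + 1)
        result[1::2] = token_ids
        return result
)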
@@ -18,21 +18,25 @@

import argparse
import logging
import numpy as np
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union

import k2
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from torch.optim import Optimizer
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import LJSpeechTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
@@ -41,11 +45,6 @@ from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, setup_logger, str2bool

from tokenizer import Tokenizer
from tts_datamodule import LJSpeechTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS

LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
@@ -385,11 +384,12 @@ def train_one_epoch(
        params.batch_idx_train += 1

        batch_size = len(batch["tokens"])
        audio, audio_lens, features, features_lens, tokens, tokens_lens = \
            prepare_input(batch, tokenizer, device)
        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
            batch, tokenizer, device
        )

        loss_info = MetricsTracker()
        loss_info['samples'] = batch_size
        loss_info["samples"] = batch_size

        try:
            with autocast(enabled=params.use_fp16):
@@ -446,7 +446,9 @@ def train_one_epoch(
                # behavior depending on the current grad scale.
                cur_grad_scale = scaler._scale.item()

                if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
                if cur_grad_scale < 8.0 or (
                    cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
                ):
                    scaler.update(cur_grad_scale * 2.0)
                if cur_grad_scale < 0.01:
                    if not saved_bad_model:
@@ -482,9 +484,7 @@ def train_one_epoch(
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(
                    tb_writer, "train/tot_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
@@ -492,19 +492,34 @@ def train_one_epoch(
            if "returned_sample" in stats_g:
                speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
                tb_writer.add_audio(
                    "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
                    "train/speech_hat_",
                    speech_hat_,
                    params.batch_idx_train,
                    params.sampling_rate,
                )
                tb_writer.add_audio(
                    "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
                    "train/speech_",
                    speech_,
                    params.batch_idx_train,
                    params.sampling_rate,
                )
                tb_writer.add_image(
                    "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
                    "train/mel_hat_",
                    plot_feature(mel_hat_),
                    params.batch_idx_train,
                    dataformats="HWC",
                )
                tb_writer.add_image(
                    "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
                    "train/mel_",
                    plot_feature(mel_),
                    params.batch_idx_train,
                    dataformats="HWC",
                )

        if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
        if (
            params.batch_idx_train % params.valid_interval == 0
            and not params.print_diagnostics
        ):
            logging.info("Computing validation loss")
            valid_info, (speech_hat, speech) = compute_validation_loss(
                params=params,
@@ -523,10 +538,16 @@ def train_one_epoch(
                    tb_writer, "train/valid_", params.batch_idx_train
                )
                tb_writer.add_audio(
                    "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
                    "train/valdi_speech_hat",
                    speech_hat,
                    params.batch_idx_train,
                    params.sampling_rate,
                )
                tb_writer.add_audio(
                    "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
                    "train/valdi_speech",
                    speech,
                    params.batch_idx_train,
                    params.sampling_rate,
                )

    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
@@ -555,11 +576,17 @@ def compute_validation_loss(
    with torch.no_grad():
        for batch_idx, batch in enumerate(valid_dl):
            batch_size = len(batch["tokens"])
            audio, audio_lens, features, features_lens, tokens, tokens_lens = \
                prepare_input(batch, tokenizer, device)
            (
                audio,
                audio_lens,
                features,
                features_lens,
                tokens,
                tokens_lens,
            ) = prepare_input(batch, tokenizer, device)

            loss_info = MetricsTracker()
            loss_info['samples'] = batch_size
            loss_info["samples"] = batch_size

            # forward discriminator
            loss_d, stats_d = model(
@@ -599,8 +626,13 @@ def compute_validation_loss(
                    text=tokens[0, : tokens_lens[0].item()]
                )
                audio_pred = audio_pred.data.cpu().numpy()
                audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
                assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
                audio_len_pred = (
                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
                )
                assert audio_len_pred == len(audio_pred), (
                    audio_len_pred,
                    len(audio_pred),
                )
                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
                returned_sample = (audio_pred, audio_gt)
@@ -632,8 +664,9 @@ def scan_pessimistic_batches_for_oom(
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        audio, audio_lens, features, features_lens, tokens, tokens_lens = \
            prepare_input(batch, tokenizer, device)
        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
            batch, tokenizer, device
        )
        try:
            # for discriminator
            with autocast(enabled=params.use_fp16):
@@ -29,10 +29,10 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    SpeechSynthesisDataset,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
    SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
    AudioSamples,
@@ -14,15 +14,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import collections
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from pathlib import Path
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
@@ -97,23 +97,23 @@ def plot_feature(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data
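(An editorial aside, not part of the commit: `np.fromstring` has long been deprecated for binary buffers; on current NumPy the equivalent of the call above is

    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
)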
@@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.cuda.amp import autocast

from generator import VITSGenerator
from hifigan import (
    HiFiGANMultiPeriodDiscriminator,
    HiFiGANMultiScaleDiscriminator,
@@ -25,9 +24,8 @@ from loss import (
    KLDivergenceLoss,
    MelSpectrogramLoss,
)
from torch.cuda.amp import autocast
from utils import get_segments
from generator import VITSGenerator


AVAILABLE_GENERATERS = {
    "vits_generator": VITSGenerator,
@@ -42,8 +40,7 @@ AVAILABLE_DISCRIMINATORS = {


class VITS(nn.Module):
    """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`
    """
    """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`"""

    def __init__(
        self,
@@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

"""

import math
import logging

import math
from typing import Optional, Tuple

import torch
egs/vctk/TTS/local/compute_spectrogram_vctk.py (new executable file, 107 lines)
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# Copyright    2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                       Zengwei Yao,
#                                                       Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes spectrogram features of the VCTK dataset.
It looks for manifests in the directory data/manifests.

The generated spectrogram features are saved in data/spectrogram.
"""

import logging
import os
from pathlib import Path

import torch
from lhotse import (
    CutSet,
    LilcomChunkyWriter,
    Spectrogram,
    SpectrogramConfig,
    load_manifest,
)
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_spectrogram_vctk():
    src_dir = Path("data/manifests")
    output_dir = Path("data/spectrogram")
    num_jobs = min(32, os.cpu_count())

    sampling_rate = 22050
    frame_length = 1024 / sampling_rate  # (in second)
    frame_shift = 256 / sampling_rate  # (in second)
    use_fft_mag = True

    prefix = "vctk"
    suffix = "jsonl.gz"
    partition = "all"

    recordings = load_manifest(
        src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
    ).resample(sampling_rate=sampling_rate)
    supervisions = load_manifest(
        src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
    )

    config = SpectrogramConfig(
        sampling_rate=sampling_rate,
        frame_length=frame_length,
        frame_shift=frame_shift,
        use_fft_mag=use_fft_mag,
    )
    extractor = Spectrogram(config)

    with get_executor() as ex:  # Initialize the executor only once.
        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
        if (output_dir / cuts_filename).is_file():
            logging.info(f"{partition} already exists - skipping.")
            return
        logging.info(f"Processing {partition}")
        cut_set = CutSet.from_manifests(
            recordings=recordings, supervisions=supervisions
        )

        cut_set = cut_set.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{output_dir}/{prefix}_feats_{partition}",
            # when an executor is specified, make more partitions
            num_jobs=num_jobs if ex is None else 80,
            executor=ex,
            storage_type=LilcomChunkyWriter,
        )
        cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    compute_spectrogram_vctk()
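(A brief aside on the frame geometry above, not part of the commit: lhotse's SpectrogramConfig takes window and hop in seconds, hence the divisions by the sampling rate. In Python:

    sampling_rate = 22050
    frame_length = 1024 / sampling_rate  # ~0.0464 s (46.4 ms) analysis window
    frame_shift = 256 / sampling_rate    # ~0.0116 s (11.6 ms) hop between frames
)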
egs/vctk/TTS/local/display_manifest_statistics.py (new executable file, 83 lines)
@@ -0,0 +1,83 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Zengwei Yao,
#                                                  Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during training.

See the function `remove_short_and_long_utt()` in vits/train.py
for usage.
"""


from lhotse import load_manifest_lazy


def main():
    path = "./data/spectrogram/vctk_cuts_all.jsonl.gz"
    cuts = load_manifest_lazy(path)
    cuts.describe()


if __name__ == "__main__":
    main()

"""
Cut statistics:
╒═══════════════════════════╤══════════╕
│ Cuts count:               │ 43873    │
├───────────────────────────┼──────────┤
│ Total duration (hh:mm:ss) │ 41:02:18 │
├───────────────────────────┼──────────┤
│ mean                      │ 3.4      │
├───────────────────────────┼──────────┤
│ std                       │ 1.2      │
├───────────────────────────┼──────────┤
│ min                       │ 1.2      │
├───────────────────────────┼──────────┤
│ 25%                       │ 2.6      │
├───────────────────────────┼──────────┤
│ 50%                       │ 3.1      │
├───────────────────────────┼──────────┤
│ 75%                       │ 3.8      │
├───────────────────────────┼──────────┤
│ 99%                       │ 8.0      │
├───────────────────────────┼──────────┤
│ 99.5%                     │ 9.1      │
├───────────────────────────┼──────────┤
│ 99.9%                     │ 12.1     │
├───────────────────────────┼──────────┤
│ max                       │ 16.6     │
├───────────────────────────┼──────────┤
│ Recordings available:     │ 43873    │
├───────────────────────────┼──────────┤
│ Features available:       │ 43873    │
├───────────────────────────┼──────────┤
│ Supervisions available:   │ 43873    │
╘═══════════════════════════╧══════════╛
SUPERVISION custom fields:
Speech duration statistics:
╒══════════════════════════════╤══════════╤══════════════════════╕
│ Total speech duration        │ 41:02:18 │ 100.00% of recording │
├──────────────────────────────┼──────────┼──────────────────────┤
│ Total speaking time duration │ 41:02:18 │ 100.00% of recording │
├──────────────────────────────┼──────────┼──────────────────────┤
│ Total silence duration       │ 00:00:01 │ 0.00% of recording   │
╘══════════════════════════════╧══════════╧══════════════════════╛
"""
egs/vctk/TTS/local/prepare_token_file.py (new executable file, 104 lines)
@@ -0,0 +1,104 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file reads the texts in the given manifest and generates the file that maps tokens to IDs.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict

from lhotse import load_manifest


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--manifest-file",
        type=Path,
        default=Path("data/spectrogram/vctk_cuts_all.jsonl.gz"),
        help="Path to the manifest file",
    )

    parser.add_argument(
        "--tokens",
        type=Path,
        default=Path("data/tokens.txt"),
        help="Path to the tokens",
    )

    return parser.parse_args()


def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
    """Write a symbol to ID mapping to a file.

    Note:
      No need to implement `read_mapping` as it can be done
      through :func:`k2.SymbolTable.from_file`.

    Args:
      filename:
        Filename to save the mapping.
      sym2id:
        A dict mapping symbols to IDs.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for sym, i in sym2id.items():
            f.write(f"{sym} {i}\n")


def get_token2id(manifest_file: Path) -> Dict[str, int]:
    """Return a dict that maps token to IDs."""
    extra_tokens = [
        "<blk>",  # 0 for blank
        "<sos/eos>",  # 1 for sos and eos symbols.
        "<unk>",  # 2 for OOV
    ]
    all_tokens = set()

    cut_set = load_manifest(manifest_file)

    for cut in cut_set:
        # Each cut only contains one supervision
        assert len(cut.supervisions) == 1, len(cut.supervisions)
        for t in cut.tokens:
            all_tokens.add(t)

    all_tokens = extra_tokens + list(all_tokens)

    token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
    return token2id


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    manifest_file = Path(args.manifest_file)
    out_file = Path(args.tokens)

    token2id = get_token2id(manifest_file)
    write_mapping(out_file, token2id)
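(A note on the output format, inferred from `write_mapping` above: each line of data/tokens.txt is a symbol followed by its integer ID, with the three special tokens first. For example:

    <blk> 0
    <sos/eos> 1
    <unk> 2
    AH0 3

The IDs after the specials depend on the traversal order of the token set, so the exact remaining lines will vary between runs.)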
egs/vctk/TTS/local/prepare_tokens_vctk.py (new executable file, 61 lines)
@@ -0,0 +1,61 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Zengwei Yao,
#                                                  Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file reads the texts in the given manifest and saves the new cuts with phoneme tokens.
"""

import logging
from pathlib import Path

import g2p_en
import tacotron_cleaner.cleaners
from lhotse import CutSet, load_manifest
from tqdm.auto import tqdm


def prepare_tokens_vctk():
    output_dir = Path("data/spectrogram")
    prefix = "vctk"
    suffix = "jsonl.gz"
    partition = "all"

    cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
    g2p = g2p_en.G2p()

    new_cuts = []
    for cut in tqdm(cut_set):
        # Each cut only contains one supervision
        assert len(cut.supervisions) == 1, len(cut.supervisions)
        text = cut.supervisions[0].text
        # Text normalization
        text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
        # Convert to phonemes
        cut.tokens = g2p(text)
        new_cuts.append(cut)

    new_cut_set = CutSet.from_cuts(new_cuts)
    new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    prepare_tokens_vctk()
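(A quick illustration, not part of the commit, of what the cleaner-plus-g2p pipeline above produces; the exact token list depends on the tacotron_cleaner and g2p_en versions installed:

    import g2p_en
    import tacotron_cleaner.cleaners

    g2p = g2p_en.G2p()
    text = tacotron_cleaner.cleaners.custom_english_cleaners("Hello, world!")
    tokens = g2p(text)
    # e.g. ['HH', 'AH0', 'L', 'OW1', ',', ' ', 'W', 'ER1', 'L', 'D', '!']
    # note that spaces and punctuation come through as tokens of their own
)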
egs/vctk/TTS/local/validate_manifest.py (new executable file, 70 lines)
@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright    2022-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                       Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:

- Single supervision per cut

We will add more checks later if needed.

Usage example:

  python3 ./local/validate_manifest.py \
    ./data/spectrogram/vctk_cuts_all.jsonl.gz

"""

import argparse
import logging
from pathlib import Path

from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset.speech_synthesis import validate_for_tts


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "manifest",
        type=Path,
        help="Path to the manifest file",
    )

    return parser.parse_args()


def main():
    args = get_args()

    manifest = args.manifest
    logging.info(f"Validating {manifest}")

    assert manifest.is_file(), f"{manifest} does not exist"
    cut_set = load_manifest_lazy(manifest)
    assert isinstance(cut_set, CutSet)

    validate_for_tts(cut_set)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
egs/vctk/TTS/prepare.sh (new executable file, 131 lines)
@@ -0,0 +1,131 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

stage=0
stop_stage=100

dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
  log "Stage -1: build monotonic_align lib"
  if [ ! -d vits/monotonic_align/build ]; then
    cd vits/monotonic_align
    python setup.py build_ext --inplace
    cd ../../
  else
    log "monotonic_align lib already built"
  fi
fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  # If you have pre-downloaded it to /path/to/VCTK,
  # you can create a symlink
  #
  #   ln -sfv /path/to/VCTK $dl_dir/VCTK
  #
  if [ ! -d $dl_dir/VCTK ]; then
    lhotse download vctk $dl_dir
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare VCTK manifest"
  # We assume that you have downloaded the VCTK corpus
  # to $dl_dir/VCTK
  mkdir -p data/manifests
  if [ ! -e data/manifests/.vctk.done ]; then
    lhotse prepare vctk --use-edinburgh-vctk-url true $dl_dir/VCTK data/manifests
    touch data/manifests/.vctk.done
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Compute spectrogram for VCTK"
  mkdir -p data/spectrogram
  if [ ! -e data/spectrogram/.vctk.done ]; then
    ./local/compute_spectrogram_vctk.py
    touch data/spectrogram/.vctk.done
  fi

  if [ ! -e data/spectrogram/.vctk-validated.done ]; then
    log "Validating data/spectrogram for VCTK"
    ./local/validate_manifest.py \
      data/spectrogram/vctk_cuts_all.jsonl.gz
    touch data/spectrogram/.vctk-validated.done
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Prepare phoneme tokens for VCTK"
  if [ ! -e data/spectrogram/.vctk_with_token.done ]; then
    ./local/prepare_tokens_vctk.py
    mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \
      data/spectrogram/vctk_cuts_all.jsonl.gz
    touch data/spectrogram/.vctk_with_token.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Split the VCTK cuts into train, valid and test sets"
  if [ ! -e data/spectrogram/.vctk_split.done ]; then
    lhotse subset --last 600 \
      data/spectrogram/vctk_cuts_all.jsonl.gz \
      data/spectrogram/vctk_cuts_validtest.jsonl.gz
    lhotse subset --first 100 \
      data/spectrogram/vctk_cuts_validtest.jsonl.gz \
      data/spectrogram/vctk_cuts_valid.jsonl.gz
    lhotse subset --last 500 \
      data/spectrogram/vctk_cuts_validtest.jsonl.gz \
      data/spectrogram/vctk_cuts_test.jsonl.gz

    rm data/spectrogram/vctk_cuts_validtest.jsonl.gz

    n=$(( $(gunzip -c data/spectrogram/vctk_cuts_all.jsonl.gz | wc -l) - 600 ))
    lhotse subset --first $n \
      data/spectrogram/vctk_cuts_all.jsonl.gz \
      data/spectrogram/vctk_cuts_train.jsonl.gz
    touch data/spectrogram/.vctk_split.done
  fi
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Generate token file"
  # We assume you have installed g2p_en and espnet_tts_frontend.
  # If not, please install them with:
  #   - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
  #   - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
  if [ ! -e data/tokens.txt ]; then
    ./local/prepare_token_file.py \
      --manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \
      --tokens data/tokens.txt
  fi
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Generate speakers file"
  if [ ! -e data/speakers.txt ]; then
    gunzip -c data/manifests/vctk_supervisions_all.jsonl.gz \
      | jq '.speaker' | sed 's/"//g' \
      | sort | uniq > data/speakers.txt
  fi
fi
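(A hedged sketch, not part of the commit, of how data/speakers.txt produced in Stage 6 is consumed: vits/export-onnx.py later in this commit builds the speaker-to-ID map by line number, so the sorted speaker list defines the integer speaker IDs:

    # one speaker name per line, sorted; the line index becomes the speaker ID
    with open("data/speakers.txt") as f:
        speaker_map = {line.strip(): i for i, line in enumerate(f)}
    # e.g. {"p225": 0, "p226": 1, ...} for VCTK speaker names
)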
egs/vctk/TTS/shared (symbolic link)
@@ -0,0 +1 @@
../../../icefall/shared/
egs/vctk/TTS/vits/duration_predictor.py (symbolic link)
@@ -0,0 +1 @@
../../../ljspeech/TTS/vits/duration_predictor.py
egs/vctk/TTS/vits/export-onnx.py (new executable file, 284 lines)
@@ -0,0 +1,284 @@
#!/usr/bin/env python3
#
# Copyright      2023 Xiaomi Corporation     (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script exports a VITS model from PyTorch to ONNX.

Export the model to ONNX:
./vits/export-onnx.py \
  --epoch 1000 \
  --exp-dir vits/exp \
  --tokens data/tokens.txt

It will generate two files inside vits/exp:
  - vits-epoch-1000.onnx
  - vits-epoch-1000.int8.onnx (quantized model)

See ./test_onnx.py for how to use the exported ONNX models.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple

import onnx
import torch
import torch.nn as nn
from onnxruntime.quantization import QuantType, quantize_dynamic
from tokenizer import Tokenizer
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=1000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="vits/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--speakers",
        type=Path,
        default=Path("data/speakers.txt"),
        help="Path to speakers.txt file.",
    )
    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    return parser


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


class OnnxModel(nn.Module):
    """A wrapper for VITS generator."""

    def __init__(self, model: nn.Module):
        """
        Args:
          model:
            A VITS generator.
          frame_shift:
            The frame shift in samples.
        """
        super().__init__()
        self.model = model

    def forward(
        self,
        tokens: torch.Tensor,
        tokens_lens: torch.Tensor,
        noise_scale: float = 0.667,
        noise_scale_dur: float = 0.8,
        speaker: int = 20,
        alpha: float = 1.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Please see the help information of VITS.inference_batch

        Args:
          tokens:
            Input text token indexes (1, T_text)
          tokens_lens:
            Number of tokens of shape (1,)
          noise_scale (float):
            Noise scale parameter for flow.
          noise_scale_dur (float):
            Noise scale parameter for duration predictor.
          speaker (int):
            Speaker ID.
          alpha (float):
            Alpha parameter to control the speed of generated speech.

        Returns:
          Return a tuple containing:
            - audio, generated waveform tensor, (B, T_wav)
        """
        audio, _, _ = self.model.inference(
            text=tokens,
            text_lengths=tokens_lens,
            noise_scale=noise_scale,
            noise_scale_dur=noise_scale_dur,
            sids=speaker,
            alpha=alpha,
        )
        return audio


def export_model_onnx(
    model: nn.Module,
    model_filename: str,
    opset_version: int = 11,
) -> None:
    """Export the given generator model to ONNX format.
    The exported model has one input:

      - tokens, a tensor of shape (1, T_text); dtype is torch.int64

    and it has one output:

      - audio, a tensor of shape (1, T'); dtype is torch.float32

    Args:
      model:
        The VITS generator.
      model_filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """
    tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    noise_scale_dur = torch.tensor([1], dtype=torch.float32)
    alpha = torch.tensor([1], dtype=torch.float32)
    speaker = torch.tensor([1], dtype=torch.int64)

    torch.onnx.export(
        model,
        (tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha),
        model_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "tokens",
            "tokens_lens",
            "noise_scale",
            "noise_scale_dur",
            "speaker",
            "alpha",
        ],
        output_names=["audio"],
        dynamic_axes={
            "tokens": {0: "N", 1: "T"},
            "tokens_lens": {0: "N"},
            "audio": {0: "N", 1: "T"},
            "speaker": {0: "N"},
        },
    )

    meta_data = {
        "model_type": "VITS",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "VITS generator",
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=model_filename, meta_data=meta_data)


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.blank_id
    params.oov_id = tokenizer.oov_id
    params.vocab_size = tokenizer.vocab_size

    with open(args.speakers) as f:
        speaker_map = {line.strip(): i for i, line in enumerate(f)}
    params.num_spks = len(speaker_map)

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

    model = model.generator
    model.to("cpu")
    model.eval()

    model = OnnxModel(model=model)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"generator parameters: {num_param}")

    suffix = f"epoch-{params.epoch}"

    opset_version = 13

    logging.info("Exporting generator")
    model_filename = params.exp_dir / f"vits-{suffix}.onnx"
    export_model_onnx(
        model,
        model_filename,
        opset_version=opset_version,
    )
    logging.info(f"Exported generator to {model_filename}")

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    logging.info("Generate int8 quantization models")

    model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
    quantize_dynamic(
        model_input=model_filename,
        model_output=model_filename_int8,
        weight_type=QuantType.QUInt8,
    )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
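(A small usage note, an assumption rather than part of this commit: the metadata written by add_meta_data above can be read back through onnxruntime's model metadata, e.g.:

    import onnxruntime as ort

    sess = ort.InferenceSession("vits/exp/vits-epoch-1000.onnx")
    print(sess.get_modelmeta().custom_metadata_map)
    # expected: {'model_type': 'VITS', 'version': '1',
    #            'model_author': 'k2-fsa', 'comment': 'VITS generator'}
)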
egs/vctk/TTS/vits/flow.py (symbolic link)
@@ -0,0 +1 @@
../../../ljspeech/TTS/vits/flow.py
egs/vctk/TTS/vits/generator.py (symbolic link)
@@ -0,0 +1 @@
../../../ljspeech/TTS/vits/generator.py
egs/vctk/TTS/vits/hifigan.py (symbolic link)
@@ -0,0 +1 @@
../../../ljspeech/TTS/vits/hifigan.py
272 egs/vctk/TTS/vits/infer.py Executable file
@ -0,0 +1,272 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao,
#                                            Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script performs model inference on the test set.

Usage:
./vits/infer.py \
    --epoch 1000 \
    --exp-dir ./vits/exp \
    --max-duration 500
"""


import argparse
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List

import k2
import torch
import torch.nn as nn
import torchaudio
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import VctkTtsDataModule

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=1000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="vits/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    return parser


def infer_dataset(
    dl: torch.utils.data.DataLoader,
    subset: str,
    params: AttributeDict,
    model: nn.Module,
    tokenizer: Tokenizer,
    speaker_map: Dict[str, int],
) -> None:
    """Decode the dataset.
    The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      subset:
        Name of the subset being decoded, e.g. "test" or "valid".
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      tokenizer:
        Used to convert text to phonemes.
      speaker_map:
        A dict that maps speaker names to integer speaker ids.
    """

    # Background worker that saves audio to disk.
    def _save_worker(
        subset: str,
        batch_size: int,
        cut_ids: List[str],
        audio: torch.Tensor,
        audio_pred: torch.Tensor,
        audio_lens: List[int],
        audio_lens_pred: List[int],
    ):
        for i in range(batch_size):
            torchaudio.save(
                str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
                audio[i : i + 1, : audio_lens[i]],
                sample_rate=params.sampling_rate,
            )
            torchaudio.save(
                str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"),
                audio_pred[i : i + 1, : audio_lens_pred[i]],
                sample_rate=params.sampling_rate,
            )

    device = next(model.parameters()).device
    num_cuts = 0
    log_interval = 5

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    futures = []
    with ThreadPoolExecutor(max_workers=1) as executor:
        for batch_idx, batch in enumerate(dl):
            batch_size = len(batch["tokens"])

            tokens = batch["tokens"]
            tokens = tokenizer.tokens_to_token_ids(tokens)
            tokens = k2.RaggedTensor(tokens)
            row_splits = tokens.shape.row_splits(1)
            tokens_lens = row_splits[1:] - row_splits[:-1]
            tokens = tokens.to(device)
            tokens_lens = tokens_lens.to(device)
            # tensor of shape (B, T)
            tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
            speakers = (
                torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
                .int()
                .to(device)
            )

            audio = batch["audio"]
            audio_lens = batch["audio_lens"].tolist()
            cut_ids = [cut.id for cut in batch["cut"]]

            audio_pred, _, durations = model.inference_batch(
                text=tokens,
                text_lengths=tokens_lens,
                sids=speakers,
            )
            audio_pred = audio_pred.detach().cpu()
            # convert to samples
            audio_lens_pred = (
                (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
            )

            futures.append(
                executor.submit(
                    _save_worker,
                    subset,
                    batch_size,
                    cut_ids,
                    audio,
                    audio_pred,
                    audio_lens,
                    audio_lens_pred,
                )
            )

            num_cuts += batch_size

            if batch_idx % log_interval == 0:
                batch_str = f"{batch_idx}/{num_batches}"

                logging.info(
                    f"batch {batch_str}, cuts processed until now is {num_cuts}"
                )
        # wait for all background save workers to finish
        for f in futures:
            f.result()


@torch.no_grad()
def main():
    parser = get_parser()
    VctkTtsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.suffix = f"epoch-{params.epoch}"

    params.res_dir = params.exp_dir / "infer" / params.suffix
    params.save_wav_dir = params.res_dir / "wav"
    params.save_wav_dir.mkdir(parents=True, exist_ok=True)

    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
    logging.info("Infer started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.blank_id
    params.oov_id = tokenizer.oov_id
    params.vocab_size = tokenizer.vocab_size

    # we need cut ids to name the saved wav files.
    args.return_cuts = True
    vctk = VctkTtsDataModule(args)
    speaker_map = vctk.speakers()
    params.num_spks = len(speaker_map)

    logging.info(f"Device: {device}")
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

    model.to(device)
    model.eval()

    num_param_g = sum([p.numel() for p in model.generator.parameters()])
    logging.info(f"Number of parameters in generator: {num_param_g}")
    num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
    logging.info(f"Number of parameters in discriminator: {num_param_d}")
    logging.info(f"Total number of parameters: {num_param_g + num_param_d}")

    test_cuts = vctk.test_cuts()
    test_dl = vctk.test_dataloaders(test_cuts)

    valid_cuts = vctk.valid_cuts()
    valid_dl = vctk.valid_dataloaders(valid_cuts)

    infer_sets = {"test": test_dl, "valid": valid_dl}

    for subset, dl in infer_sets.items():
        save_wav_dir = params.res_dir / "wav" / subset
        save_wav_dir.mkdir(parents=True, exist_ok=True)

        logging.info(f"Processing {subset} set, saving to {save_wav_dir}")

        infer_dataset(
            dl=dl,
            subset=subset,
            params=params,
            model=model,
            tokenizer=tokenizer,
            speaker_map=speaker_map,
        )

    logging.info(f"Wav files are saved to {params.save_wav_dir}")
    logging.info("Done!")


if __name__ == "__main__":
    main()
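One detail worth calling out in infer_dataset above: the predicted lengths come from the duration predictor, so the conversion to samples multiplies total predicted frames by the hop size. A worked illustration, assuming params.frame_shift is the hop size in samples (256 would match a frame_shift of 256 / sampling_rate seconds in this recipe's spectrogram config):

# Illustration only (hypothetical numbers): converting predicted durations
# (frames per token) into a waveform length in samples.
import torch

durations = torch.full((1, 100), 4.0)  # (B, T_tokens): 4 frames per token
frame_shift = 256                      # hop size in samples (assumed)
audio_lens_pred = (durations.sum(1) * frame_shift).to(dtype=torch.int64).tolist()
print(audio_lens_pred)  # [102400] -> about 4.6 seconds at 22050 Hz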
1 egs/vctk/TTS/vits/loss.py Symbolic link
@ -0,0 +1 @@
../../../ljspeech/TTS/vits/loss.py
1 egs/vctk/TTS/vits/monotonic_align Symbolic link
@ -0,0 +1 @@
../../../ljspeech/TTS/vits/monotonic_align
1 egs/vctk/TTS/vits/posterior_encoder.py Symbolic link
@ -0,0 +1 @@
../../../ljspeech/TTS/vits/posterior_encoder.py
1 egs/vctk/TTS/vits/residual_coupling.py Symbolic link
@ -0,0 +1 @@
../../../ljspeech/TTS/vits/residual_coupling.py
138 egs/vctk/TTS/vits/test_onnx.py Executable file
@ -0,0 +1,138 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is used to test the onnx model exported by vits/export-onnx.py.

Use the onnx model to generate a wav:
./vits/test_onnx.py \
  --model-filename vits/exp/vits-epoch-1000.onnx \
  --speakers data/speakers.txt \
  --tokens data/tokens.txt
"""


import argparse
import logging
from pathlib import Path

import onnxruntime as ort
import torch
import torchaudio
from tokenizer import Tokenizer


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-filename",
        type=str,
        required=True,
        help="Path to the onnx model.",
    )

    parser.add_argument(
        "--speakers",
        type=Path,
        default=Path("data/speakers.txt"),
        help="Path to speakers.txt file.",
    )
    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    return parser


class OnnxModel:
    def __init__(self, model_filename: str):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 4

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")

    def __call__(
        self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
          tokens:
            A 2-D tensor of shape (1, T)
          tokens_lens:
            A 1-D tensor of shape (1,)
          speaker:
            A 1-D tensor of shape (1,) containing the speaker id.
        Returns:
          A tensor of shape (1, T')
        """
        noise_scale = torch.tensor([0.667], dtype=torch.float32)
        noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
        alpha = torch.tensor([1.0], dtype=torch.float32)

        out = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: tokens.numpy(),
                self.model.get_inputs()[1].name: tokens_lens.numpy(),
                self.model.get_inputs()[2].name: noise_scale.numpy(),
                self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
                self.model.get_inputs()[4].name: speaker.numpy(),
                self.model.get_inputs()[5].name: alpha.numpy(),
            },
        )[0]
        return torch.from_numpy(out)


def main():
    args = get_parser().parse_args()

    tokenizer = Tokenizer(args.tokens)

    with open(args.speakers) as f:
        speaker_map = {line.strip(): i for i, line in enumerate(f)}
    args.num_spks = len(speaker_map)

    logging.info("About to create onnx model")
    model = OnnxModel(args.model_filename)

    text = "I went there to see the land, the people and how their system works, end quote."
    tokens = tokenizer.texts_to_token_ids([text])
    tokens = torch.tensor(tokens)  # (1, T)
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1,)
    speaker = torch.tensor([1], dtype=torch.int64)  # (1,)
    audio = model(tokens, tokens_lens, speaker)  # (1, T')

    torchaudio.save("test_onnx.wav", audio, sample_rate=22050)
    logging.info("Saved to test_onnx.wav")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
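main() above hard-codes speaker id 1. Since the script already builds speaker_map from speakers.txt, a natural variation is to select a speaker by name instead. A sketch, assuming "p225" is one of the lines in data/speakers.txt:

# Sketch only: pick a VCTK speaker by name rather than a hard-coded id.
# Assumes "p225" appears in data/speakers.txt; adjust to your file.
import torch

with open("data/speakers.txt") as f:
    speaker_map = {line.strip(): i for i, line in enumerate(f)}

speaker = torch.tensor([speaker_map["p225"]], dtype=torch.int64)  # (1,)
# then: audio = model(tokens, tokens_lens, speaker)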
1 egs/vctk/TTS/vits/text_encoder.py Symbolic link
@ -0,0 +1 @@
../../../ljspeech/TTS/vits/text_encoder.py
1 egs/vctk/TTS/vits/tokenizer.py Symbolic link
@ -0,0 +1 @@
../../../ljspeech/TTS/vits/tokenizer.py
1000 egs/vctk/TTS/vits/train.py Executable file
File diff suppressed because it is too large.
1 egs/vctk/TTS/vits/transform.py Symbolic link
@ -0,0 +1 @@
../../../ljspeech/TTS/vits/transform.py
338 egs/vctk/TTS/vits/tts_datamodule.py Normal file
@ -0,0 +1,338 @@
# Copyright      2021  Piotr Żelasko
# Copyright 2022-2023  Xiaomi Corporation  (Authors: Mingshuang Luo,
#                                                    Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
    SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
    AudioSamples,
    OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class VctkTtsDataModule:
    """
    DataModule for TTS experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in TTS
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in TTS tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/spectrogram"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--speakers",
            type=Path,
            default=Path("data/speakers.txt"),
            help="Path to speakers.txt file.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )

        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=False,
            help="When enabled, each batch will have the "
            "field: batch['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=8,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to create train dataset")
        train = SpeechSynthesisDataset(
            return_text=False,
            return_tokens=True,
            return_spk_ids=True,
            feature_input_strategy=eval(self.args.input_strategy)(),
            return_cuts=self.args.return_cuts,
        )

        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in seconds)
                frame_shift=256 / sampling_rate,  # (in seconds)
                use_fft_mag=True,
            )
            train = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                return_spk_ids=True,
                feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                return_cuts=self.args.return_cuts,
            )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in seconds)
                frame_shift=256 / sampling_rate,  # (in seconds)
                use_fft_mag=True,
            )
            validate = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                return_spk_ids=True,
                feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                return_spk_ids=True,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create valid dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.info("About to create test dataset")
        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in seconds)
                frame_shift=256 / sampling_rate,  # (in seconds)
                use_fft_mag=True,
            )
            test = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                return_spk_ids=True,
                feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            test = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                return_spk_ids=True,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        test_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=test_sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz")

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        logging.info("About to get validation cuts")
        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz")

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.info("About to get test cuts")
        return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz")

    @lru_cache()
    def speakers(self) -> Dict[str, int]:
        logging.info("About to get speakers")
        with open(self.args.speakers) as f:
            speakers = {line.strip(): i for i, line in enumerate(f)}
        return speakers
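Putting the pieces together, a script consumes this data module in three steps: register its CLI options on a parser, construct it from the parsed args, then ask it for cuts and dataloaders. A minimal usage sketch, mirroring what infer.py above does:

# Minimal sketch of wiring up VctkTtsDataModule, as the recipe scripts do.
import argparse

from tts_datamodule import VctkTtsDataModule

parser = argparse.ArgumentParser()
VctkTtsDataModule.add_arguments(parser)  # adds --manifest-dir, --max-duration, ...
args = parser.parse_args()

vctk = VctkTtsDataModule(args)
train_dl = vctk.train_dataloaders(vctk.train_cuts())
speaker_map = vctk.speakers()  # {speaker name: integer id}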
1 egs/vctk/TTS/vits/utils.py Symbolic link
@ -0,0 +1 @@
../../../ljspeech/TTS/vits/utils.py
1 egs/vctk/TTS/vits/vits.py Symbolic link
@ -0,0 +1 @@
../../../ljspeech/TTS/vits/vits.py
1 egs/vctk/TTS/vits/wavenet.py Symbolic link
@ -0,0 +1 @@
../../../ljspeech/TTS/vits/wavenet.py
6 requirements-tts.txt Normal file
@ -0,0 +1,6 @@
# for TTS recipes
matplotlib==3.8.2
cython==3.0.6
numba==0.58.1
g2p_en==2.1.0
espnet_tts_frontend==0.0.3
requirements.txt
@ -8,3 +8,5 @@ tensorboard
typeguard
dill
black==22.3.0
onnx==1.15.0
onnxruntime==1.16.3