diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py index c29a28479..1a8190014 100644 --- a/egs/ljspeech/TTS/vits/duration_predictor.py +++ b/egs/ljspeech/TTS/vits/duration_predictor.py @@ -14,7 +14,6 @@ from typing import Optional import torch import torch.nn.functional as F - from flow import ( ConvFlow, DilatedDepthSeparableConv, diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 154de4bf4..2068adeea 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -180,7 +180,13 @@ def export_model_onnx( model_filename, verbose=False, opset_version=opset_version, - input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"], + input_names=[ + "tokens", + "tokens_lens", + "noise_scale", + "noise_scale_dur", + "alpha", + ], output_names=["audio"], dynamic_axes={ "tokens": {0: "N", 1: "T"}, diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py index 206bd5e3e..2b84f6434 100644 --- a/egs/ljspeech/TTS/vits/flow.py +++ b/egs/ljspeech/TTS/vits/flow.py @@ -13,7 +13,6 @@ import math from typing import Optional, Tuple, Union import torch - from transform import piecewise_rational_quadratic_transform diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py index efb0e254c..66c8cedb1 100644 --- a/egs/ljspeech/TTS/vits/generator.py +++ b/egs/ljspeech/TTS/vits/generator.py @@ -16,9 +16,6 @@ from typing import List, Optional, Tuple import numpy as np import torch import torch.nn.functional as F - -from icefall.utils import make_pad_mask - from duration_predictor import StochasticDurationPredictor from hifigan import HiFiGANGenerator from posterior_encoder import PosteriorEncoder @@ -26,6 +23,8 @@ from residual_coupling import ResidualAffineCouplingBlock from text_encoder import TextEncoder from utils import get_random_segments +from icefall.utils import make_pad_mask + class VITSGenerator(torch.nn.Module): """Generator module in VITS, `Conditional Variational Autoencoder diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 91a35e360..cf0d20ae2 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -36,13 +36,12 @@ import k2 import torch import torch.nn as nn import torchaudio - -from train import get_model, get_params from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import LJSpeechTtsDataModule from icefall.checkpoint import load_checkpoint from icefall.utils import AttributeDict, setup_logger -from tts_datamodule import LJSpeechTtsDataModule def get_parser(): @@ -107,12 +106,12 @@ def infer_dataset( for i in range(batch_size): torchaudio.save( str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"), - audio[i:i + 1, :audio_lens[i]], + audio[i : i + 1, : audio_lens[i]], sample_rate=params.sampling_rate, ) torchaudio.save( str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"), - audio_pred[i:i + 1, :audio_lens_pred[i]], + audio_pred[i : i + 1, : audio_lens_pred[i]], sample_rate=params.sampling_rate, ) @@ -144,14 +143,24 @@ def infer_dataset( audio_lens = batch["audio_lens"].tolist() cut_ids = [cut.id for cut in batch["cut"]] - audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens) + audio_pred, _, durations = model.inference_batch( + text=tokens, text_lengths=tokens_lens + ) audio_pred = audio_pred.detach().cpu() # convert to samples - audio_lens_pred = (durations.sum(1) * 
params.frame_shift).to(dtype=torch.int64).tolist() + audio_lens_pred = ( + (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + ) futures.append( executor.submit( - _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred + _save_worker, + batch_size, + cut_ids, + audio, + audio_pred, + audio_lens, + audio_lens_pred, ) ) @@ -160,7 +169,9 @@ def infer_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) # return results for f in futures: f.result() diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py index 21aaad6e7..2f4dc9bc0 100644 --- a/egs/ljspeech/TTS/vits/loss.py +++ b/egs/ljspeech/TTS/vits/loss.py @@ -14,7 +14,6 @@ from typing import List, Tuple, Union import torch import torch.distributions as D import torch.nn.functional as F - from lhotse.features.kaldi import Wav2LogFilterBank diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py index 6b8a5be52..1104fb864 100644 --- a/egs/ljspeech/TTS/vits/posterior_encoder.py +++ b/egs/ljspeech/TTS/vits/posterior_encoder.py @@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits. from typing import Optional, Tuple import torch +from wavenet import Conv1d, WaveNet from icefall.utils import make_pad_mask -from wavenet import WaveNet, Conv1d class PosteriorEncoder(torch.nn.Module): diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py index 2d6807cb7..f9a2a3786 100644 --- a/egs/ljspeech/TTS/vits/residual_coupling.py +++ b/egs/ljspeech/TTS/vits/residual_coupling.py @@ -12,7 +12,6 @@ This code is based on https://github.com/jaywalnut310/vits. 
from typing import Optional, Tuple, Union import torch - from flow import FlipFlow from wavenet import WaveNet diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index 8acca7c02..686fee2a0 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -28,10 +28,10 @@ Use the onnx model to generate a wav: import argparse import logging + import onnxruntime as ort import torch import torchaudio - from tokenizer import Tokenizer diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py index 9f337e45b..fcbae7103 100644 --- a/egs/ljspeech/TTS/vits/text_encoder.py +++ b/egs/ljspeech/TTS/vits/text_encoder.py @@ -169,9 +169,7 @@ class Transformer(nn.Module): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - x = self.encoder( - x, pos_emb, key_padding_mask=key_padding_mask - ) # (T, N, C) + x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C) x = self.after_norm(x) @@ -207,7 +205,9 @@ class TransformerEncoderLayer(nn.Module): nn.Linear(dim_feedforward, d_model), ) - self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout) + self.self_attn = RelPositionMultiheadAttention( + d_model, num_heads, dropout=dropout + ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) @@ -242,7 +242,9 @@ class TransformerEncoderLayer(nn.Module): key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) """ # macaron style feed-forward module - src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src))) + src = src + self.ff_scale * self.dropout( + self.feed_forward_macaron(self.norm_ff_macaron(src)) + ) # multi-head self-attention module src_attn = self.self_attn( @@ -490,11 +492,17 @@ class RelPositionMultiheadAttention(nn.Module): q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) - v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + v = ( + v.contiguous() + .view(seq_len, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim) - p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim) + p = self.linear_pos(pos_emb).view( + pos_emb.size(0), -1, self.num_heads, self.head_dim + ) # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1) p = p.permute(0, 2, 3, 1) @@ -506,15 +514,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch_size, num_head, seq_len, seq_len) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1) - matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch_size, num_head, seq_len, 2*seq_len-1) + matrix_bd = self.rel_shift( + matrix_bd + ) # (batch_size, num_head, seq_len, seq_len) # (batch_size, num_head, seq_len, seq_len) attn_output_weights = (matrix_ac + matrix_bd) * scaling - 
attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len) + attn_output_weights = attn_output_weights.view( + batch_size * self.num_heads, seq_len, seq_len + ) if key_padding_mask is not None: assert key_padding_mask.shape == (batch_size, seq_len) @@ -536,10 +552,16 @@ class RelPositionMultiheadAttention(nn.Module): # (batch_size * num_head, seq_len, head_dim) attn_output = torch.bmm(attn_output_weights, v) - assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim) + assert attn_output.shape == ( + batch_size * self.num_heads, + seq_len, + self.head_dim, + ) attn_output = ( - attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(seq_len, batch_size, self.embed_dim) ) # (seq_len, batch_size, embed_dim) attn_output = self.out_proj(attn_output) diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 0678b26fe..70f1240b4 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -78,7 +78,9 @@ class Tokenizer(object): return token_ids_list - def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True): + def tokens_to_token_ids( + self, tokens_list: List[str], intersperse_blank: bool = True + ): """ Args: tokens_list: diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index eb43a4cc9..71c4224fa 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -18,21 +18,25 @@ import argparse import logging -import numpy as np from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union import k2 +import numpy as np import torch import torch.multiprocessing as mp import torch.nn as nn from lhotse.cut import Cut from lhotse.utils import fix_random_seed -from torch.optim import Optimizer +from tokenizer import Tokenizer from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LJSpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS from icefall import diagnostics from icefall.checkpoint import load_checkpoint @@ -41,11 +45,6 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, setup_logger, str2bool -from tokenizer import Tokenizer -from tts_datamodule import LJSpeechTtsDataModule -from utils import MetricsTracker, plot_feature, save_checkpoint -from vits import VITS - LRSchedulerType = torch.optim.lr_scheduler._LRScheduler @@ -385,11 +384,12 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["tokens"]) - audio, audio_lens, features, features_lens, tokens, tokens_lens = \ - prepare_input(batch, tokenizer, device) + audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( + batch, tokenizer, device + ) loss_info = MetricsTracker() - loss_info['samples'] = batch_size + loss_info["samples"] = batch_size try: with autocast(enabled=params.use_fp16): @@ -446,7 +446,9 @@ def train_one_epoch( # behavior depending on the current grad scale. 
             cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
+            if cur_grad_scale < 8.0 or (
+                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
+            ):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 if not saved_bad_model:
@@ -482,9 +484,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
                         "train/grad_scale", cur_grad_scale, params.batch_idx_train
@@ -492,19 +492,34 @@
                 if "returned_sample" in stats_g:
                     speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
                     tb_writer.add_audio(
-                        "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
+                        "train/speech_hat_",
+                        speech_hat_,
+                        params.batch_idx_train,
+                        params.sampling_rate,
                     )
                     tb_writer.add_audio(
-                        "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
+                        "train/speech_",
+                        speech_,
+                        params.batch_idx_train,
+                        params.sampling_rate,
                     )
                     tb_writer.add_image(
-                        "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
+                        "train/mel_hat_",
+                        plot_feature(mel_hat_),
+                        params.batch_idx_train,
+                        dataformats="HWC",
                     )
                     tb_writer.add_image(
-                        "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
+                        "train/mel_",
+                        plot_feature(mel_),
+                        params.batch_idx_train,
+                        dataformats="HWC",
                     )

-        if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
+        if (
+            params.batch_idx_train % params.valid_interval == 0
+            and not params.print_diagnostics
+        ):
             logging.info("Computing validation loss")
             valid_info, (speech_hat, speech) = compute_validation_loss(
                 params=params,
@@ -523,10 +538,16 @@
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
                 tb_writer.add_audio(
-                    "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
+                    "train/valid_speech_hat",
+                    speech_hat,
+                    params.batch_idx_train,
+                    params.sampling_rate,
                 )
                 tb_writer.add_audio(
-                    "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
+                    "train/valid_speech",
+                    speech,
+                    params.batch_idx_train,
+                    params.sampling_rate,
                 )

     loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
@@ -555,11 +576,17 @@ def compute_validation_loss(
     with torch.no_grad():
         for batch_idx, batch in enumerate(valid_dl):
             batch_size = len(batch["tokens"])
-            audio, audio_lens, features, features_lens, tokens, tokens_lens = \
-                prepare_input(batch, tokenizer, device)
+            (
+                audio,
+                audio_lens,
+                features,
+                features_lens,
+                tokens,
+                tokens_lens,
+            ) = prepare_input(batch, tokenizer, device)

             loss_info = MetricsTracker()
-            loss_info['samples'] = batch_size
+            loss_info["samples"] = batch_size

             # forward discriminator
             loss_d, stats_d = model(
@@ -596,12 +623,17 @@
             if batch_idx == 0 and rank == 0:
                 inner_model = model.module if isinstance(model, DDP) else model
                 audio_pred, _, duration = inner_model.inference(
-                    text=tokens[0, :tokens_lens[0].item()]
+                    text=tokens[0, : tokens_lens[0].item()]
                 )
                 audio_pred = audio_pred.data.cpu().numpy()
-                audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
-                assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
-                audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
+                audio_len_pred = (
+                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+                )
+                assert audio_len_pred == len(audio_pred), (
+                    audio_len_pred,
+                    len(audio_pred),
+                )
+                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
                 returned_sample = (audio_pred, audio_gt)

     if world_size > 1:
@@ -632,8 +664,9 @@ def scan_pessimistic_batches_for_oom(
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
-        audio, audio_lens, features, features_lens, tokens, tokens_lens = \
-            prepare_input(batch, tokenizer, device)
+        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
+            batch, tokenizer, device
+        )
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):
diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py
index 0fcbb92c1..81bb9ed13 100644
--- a/egs/ljspeech/TTS/vits/tts_datamodule.py
+++ b/egs/ljspeech/TTS/vits/tts_datamodule.py
@@ -29,10 +29,10 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
-    SpeechSynthesisDataset,
     PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
+    SpeechSynthesisDataset,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     AudioSamples,
diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py
index 2a3dae900..6a067f596 100644
--- a/egs/ljspeech/TTS/vits/utils.py
+++ b/egs/ljspeech/TTS/vits/utils.py
@@ -14,15 +14,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple, Union
 import collections
 import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
-import torch.nn as nn
 import torch.distributed as dist
+import torch.nn as nn
 from lhotse.dataset.sampling.base import CutSampler
-from pathlib import Path
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
@@ -97,23 +97,23 @@ def plot_feature(spectrogram):
     global MATPLOTLIB_FLAG
     if not MATPLOTLIB_FLAG:
         import matplotlib
+
         matplotlib.use("Agg")
         MATPLOTLIB_FLAG = True
-        mpl_logger = logging.getLogger('matplotlib')
+        mpl_logger = logging.getLogger("matplotlib")
         mpl_logger.setLevel(logging.WARNING)
     import matplotlib.pylab as plt
     import numpy as np

     fig, ax = plt.subplots(figsize=(10, 2))
-    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
-                   interpolation='none')
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
     plt.colorbar(im, ax=ax)
     plt.xlabel("Frames")
     plt.ylabel("Channels")
     plt.tight_layout()

     fig.canvas.draw()
-    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
     plt.close()
     return data
diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py
index d5e20a578..b4f0c21e6 100644
--- a/egs/ljspeech/TTS/vits/vits.py
+++ b/egs/ljspeech/TTS/vits/vits.py
@@ -9,8 +9,7 @@ from typing import Any, Dict, Optional, Tuple

 import torch
 import torch.nn as nn
-from torch.cuda.amp import autocast
-
+from generator import VITSGenerator
 from hifigan import (
     HiFiGANMultiPeriodDiscriminator,
     HiFiGANMultiScaleDiscriminator,
@@ -25,9 +24,8 @@ from loss import 
( KLDivergenceLoss, MelSpectrogramLoss, ) +from torch.cuda.amp import autocast from utils import get_segments -from generator import VITSGenerator - AVAILABLE_GENERATERS = { "vits_generator": VITSGenerator, @@ -42,8 +40,7 @@ AVAILABLE_DISCRIMINATORS = { class VITS(nn.Module): - """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech` - """ + """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`""" def __init__( self, diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py index fbe1be52b..5db461d5c 100644 --- a/egs/ljspeech/TTS/vits/wavenet.py +++ b/egs/ljspeech/TTS/vits/wavenet.py @@ -9,9 +9,8 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. """ -import math import logging - +import math from typing import Optional, Tuple import torch
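
Note on reproducing the reformatting: every hunk above is consistent with running black (88-column lines, double-quoted strings, magic trailing commas) together with isort over the recipe. The check script below is a minimal sketch of one way to verify that; the tool choice and the "--profile black" option are inferences from the hunks, not something this diff states, and the grouping of the `icefall` imports into their own section suggests additional repository-level isort configuration that is not shown here.

# Sketch only: check the formatting of the touched files without rewriting them.
# Assumptions (not stated in the diff): black and isort are installed, the
# working directory is the repository root, and the file list below is
# illustrative rather than exhaustive.
import subprocess
import sys

FILES = [
    "egs/ljspeech/TTS/vits/train.py",
    "egs/ljspeech/TTS/vits/infer.py",
    "egs/ljspeech/TTS/vits/utils.py",
]


def main() -> int:
    # "--check --diff" makes black report what it would change, without
    # modifying any file.
    black = subprocess.run(["black", "--check", "--diff", *FILES])
    # "--check-only --diff" does the same for isort's import ordering.
    isort = subprocess.run(
        ["isort", "--profile", "black", "--check-only", "--diff", *FILES]
    )
    return black.returncode or isort.returncode


if __name__ == "__main__":
    sys.exit(main())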