diff --git a/egs/vctk/TTS/local/prepare_token_file.py b/egs/vctk/TTS/local/prepare_token_file.py index 7f0f3a815..c6636c3ad 100755 --- a/egs/vctk/TTS/local/prepare_token_file.py +++ b/egs/vctk/TTS/local/prepare_token_file.py @@ -17,19 +17,15 @@ """ -This file reads the texts in given manifest and generate the file that maps tokens to IDs. +This file reads the texts in given manifest and generates the file that maps tokens to IDs. """ import argparse import logging -from collections import Counter from pathlib import Path from typing import Dict -import g2p_en -import tacotron_cleaner.cleaners from lhotse import load_manifest -from tqdm import tqdm def get_args(): @@ -74,35 +70,24 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: def get_token2id(manifest_file: Path) -> Dict[str, int]: """Return a dict that maps token to IDs.""" - extra_tokens = { - "": 0, # blank - "": 1, # sos and eos symbols. - "": 2, # OOV - } - cut_set = load_manifest(manifest_file) - g2p = g2p_en.G2p() - counter = Counter() + extra_tokens = [ + "", # 0 for blank + "", # 1 for sos and eos symbols. + "", # 2 for OOV + ] + all_tokens = set() - for cut in tqdm(cut_set): + cut_set = load_manifest(manifest_file) + + for cut in cut_set: # Each cut only contain one supervision assert len(cut.supervisions) == 1, len(cut.supervisions) - text = cut.supervisions[0].text - # Text normalization - text = tacotron_cleaner.cleaners.custom_english_cleaners(text) - # Convert to phonemes - tokens = g2p(text) - for t in tokens: - counter[t] += 1 + for t in cut.tokens: + all_tokens.add(t) - # Sort by the number of occurrences in descending order - tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1]) + all_tokens = extra_tokens + list(all_tokens) - for token, idx in extra_tokens.items(): - tokens_and_counts.insert(idx, (token, None)) - - token2id: Dict[str, int] = { - token: i for i, (token, count) in enumerate(tokens_and_counts) - } + token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)} return token2id diff --git a/egs/vctk/TTS/local/prepare_tokens_vctk.py b/egs/vctk/TTS/local/prepare_tokens_vctk.py new file mode 100644 index 000000000..6ee29783a --- /dev/null +++ b/egs/vctk/TTS/local/prepare_tokens_vctk.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and save the new cuts with phoneme tokens. 
+""" + +import logging +from pathlib import Path + +import g2p_en +import tacotron_cleaner.cleaners +from lhotse import CutSet, load_manifest + + +def prepare_tokens_vctk(): + output_dir = Path("data/spectrogram") + prefix = "vctk" + suffix = "jsonl.gz" + partition = "all" + + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + g2p = g2p_en.G2p() + + new_cuts = [] + for cut in cut_set: + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + text = cut.supervisions[0].normalized_text + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + cut.tokens = g2p(text) + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_tokens_vctk() diff --git a/egs/vctk/TTS/prepare.sh b/egs/vctk/TTS/prepare.sh index 589c88245..f18a09fc7 100755 --- a/egs/vctk/TTS/prepare.sh +++ b/egs/vctk/TTS/prepare.sh @@ -66,7 +66,17 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Split the VCTK cuts into train, valid and test sets" + log "Stage 3: Prepare phoneme tokens for VCTK" + if [ ! -e data/spectrogram/.vctk_with_token.done ]; then + ./local/prepare_tokens_vctk.py + mv data/spectrogram/vctk_cuts_with_tokens_all.jsonl.gz \ + data/spectrogram/vctk_cuts_all.jsonl.gz + touch data/spectrogram/.vctk_with_token.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Split the VCTK cuts into train, valid and test sets" if [ ! -e data/spectrogram/.vctk_split.done ]; then lhotse subset --last 600 \ data/spectrogram/vctk_cuts_all.jsonl.gz \ @@ -88,8 +98,12 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi fi -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Generate token file" +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate token file" + # We assume you have installed g2p_en and espnet_tts_frontend. + # If not, please install them with: + # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p + # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/tokens.txt ]; then ./local/prepare_token_file.py \ --manifest-file data/spectrogram/vctk_cuts_train.jsonl.gz \ diff --git a/egs/vctk/TTS/vits/duration_predictor.py b/egs/vctk/TTS/vits/duration_predictor.py index 1a8190014..c29a28479 100644 --- a/egs/vctk/TTS/vits/duration_predictor.py +++ b/egs/vctk/TTS/vits/duration_predictor.py @@ -14,6 +14,7 @@ from typing import Optional import torch import torch.nn.functional as F + from flow import ( ConvFlow, DilatedDepthSeparableConv, diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py new file mode 100755 index 000000000..2068adeea --- /dev/null +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a VITS model from PyTorch to ONNX. + +Export the model to ONNX: +./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +It will generate two files inside vits/exp: + - vits-epoch-1000.onnx + - vits-epoch-1000.int8.onnx (quantizated model) + +See ./test_onnx.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import torch +import torch.nn as nn +from onnxruntime.quantization import QuantType, quantize_dynamic +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for VITS generator.""" + + def __init__(self, model: nn.Module): + """ + Args: + model: + A VITS generator. + frame_shift: + The frame shift in samples. + """ + super().__init__() + self.model = model + + def forward( + self, + tokens: torch.Tensor, + tokens_lens: torch.Tensor, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of VITS.inference_batch + + Args: + tokens: + Input text token indexes (1, T_text) + tokens_lens: + Number of tokens of shape (1,) + noise_scale (float): + Noise scale parameter for flow. + noise_scale_dur (float): + Noise scale parameter for duration predictor. + alpha (float): + Alpha parameter to control the speed of generated speech. + + Returns: + Return a tuple containing: + - audio, generated wavform tensor, (B, T_wav) + """ + audio, _, _ = self.model.inference( + text=tokens, + text_lengths=tokens_lens, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + ) + return audio + + +def export_model_onnx( + model: nn.Module, + model_filename: str, + opset_version: int = 11, +) -> None: + """Export the given generator model to ONNX format. 
+ The exported model has one input: + + - tokens, a tensor of shape (1, T_text); dtype is torch.int64 + + and it has one output: + + - audio, a tensor of shape (1, T'); dtype is torch.float32 + + Args: + model: + The VITS generator. + model_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_dur = torch.tensor([1], dtype=torch.float32) + alpha = torch.tensor([1], dtype=torch.float32) + + torch.onnx.export( + model, + (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha), + model_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "tokens", + "tokens_lens", + "noise_scale", + "noise_scale_dur", + "alpha", + ], + output_names=["audio"], + dynamic_axes={ + "tokens": {0: "N", 1: "T"}, + "tokens_lens": {0: "N"}, + "audio": {0: "N", 1: "T"}, + }, + ) + + meta_data = { + "model_type": "VITS", + "version": "1", + "model_author": "k2-fsa", + "comment": "VITS generator", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=model_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model = model.generator + model.to("cpu") + model.eval() + + model = OnnxModel(model=model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"generator parameters: {num_param}") + + suffix = f"epoch-{params.epoch}" + + opset_version = 13 + + logging.info("Exporting encoder") + model_filename = params.exp_dir / f"vits-{suffix}.onnx" + export_model_onnx( + model, + model_filename, + opset_version=opset_version, + ) + logging.info(f"Exported generator to {model_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + weight_type=QuantType.QUInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/vctk/TTS/vits/flow.py b/egs/vctk/TTS/vits/flow.py index 2b84f6434..206bd5e3e 100644 --- a/egs/vctk/TTS/vits/flow.py +++ b/egs/vctk/TTS/vits/flow.py @@ -13,6 +13,7 @@ import math from typing import Optional, Tuple, Union import torch + from transform import piecewise_rational_quadratic_transform diff --git a/egs/vctk/TTS/vits/generator.py b/egs/vctk/TTS/vits/generator.py index 634b2061a..efb0e254c 100644 --- a/egs/vctk/TTS/vits/generator.py +++ b/egs/vctk/TTS/vits/generator.py @@ -16,6 +16,9 @@ from typing import List, Optional, Tuple import numpy as np import torch import torch.nn.functional as F + +from icefall.utils import make_pad_mask + from duration_predictor import 
StochasticDurationPredictor from hifigan import HiFiGANGenerator from posterior_encoder import PosteriorEncoder @@ -23,8 +26,6 @@ from residual_coupling import ResidualAffineCouplingBlock from text_encoder import TextEncoder from utils import get_random_segments -from icefall.utils import make_pad_mask - class VITSGenerator(torch.nn.Module): """Generator module in VITS, `Conditional Variational Autoencoder @@ -402,6 +403,7 @@ class VITSGenerator(torch.nn.Module): """ # encoder x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) + x_mask = x_mask.to(x.dtype) g = None if self.spks is not None: # (B, global_channels, 1) @@ -479,6 +481,7 @@ class VITSGenerator(torch.nn.Module): dur = torch.ceil(w) y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long() y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device) + y_mask = y_mask.to(x.dtype) attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) attn = self._generate_path(dur, attn_mask) diff --git a/egs/vctk/TTS/vits/infer.py b/egs/vctk/TTS/vits/infer.py index f01c5bbc4..42492da79 100755 --- a/egs/vctk/TTS/vits/infer.py +++ b/egs/vctk/TTS/vits/infer.py @@ -38,7 +38,7 @@ import torch.nn as nn import torchaudio from tokenizer import Tokenizer from train import get_model, get_params -from tts_datamodule import LJSpeechTtsDataModule +from tts_datamodule import VctkTtsDataModule from icefall.checkpoint import load_checkpoint from icefall.utils import AttributeDict, setup_logger @@ -94,6 +94,7 @@ def infer_dataset( tokenizer: Used to convert text to phonemes. """ + # Background worker save audios to disk. def _save_worker( batch_size: int, @@ -127,10 +128,10 @@ def infer_dataset( futures = [] with ThreadPoolExecutor(max_workers=1) as executor: for batch_idx, batch in enumerate(dl): - batch_size = len(batch["text"]) + batch_size = len(batch["tokens"]) - text = batch["text"] - tokens = tokenizer.texts_to_token_ids(text) + tokens = batch["tokens"] + tokens = tokenizer.tokens_to_token_ids(tokens) tokens = k2.RaggedTensor(tokens) row_splits = tokens.shape.row_splits(1) tokens_lens = row_splits[1:] - row_splits[:-1] @@ -180,7 +181,7 @@ def infer_dataset( @torch.no_grad() def main(): parser = get_parser() - LJSpeechTtsDataModule.add_arguments(parser) + VctkTtsDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -224,7 +225,7 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - ljspeech = LJSpeechTtsDataModule(args) + ljspeech = VctkTtsDataModule(args) test_cuts = ljspeech.test_cuts() test_dl = ljspeech.test_dataloaders(test_cuts) @@ -236,6 +237,7 @@ def main(): tokenizer=tokenizer, ) + logging.info(f"Wav files are saved to {params.save_wav_dir}") logging.info("Done!") diff --git a/egs/vctk/TTS/vits/loss.py b/egs/vctk/TTS/vits/loss.py index 2f4dc9bc0..21aaad6e7 100644 --- a/egs/vctk/TTS/vits/loss.py +++ b/egs/vctk/TTS/vits/loss.py @@ -14,6 +14,7 @@ from typing import List, Tuple, Union import torch import torch.distributions as D import torch.nn.functional as F + from lhotse.features.kaldi import Wav2LogFilterBank diff --git a/egs/vctk/TTS/vits/posterior_encoder.py b/egs/vctk/TTS/vits/posterior_encoder.py index 1104fb864..6b8a5be52 100644 --- a/egs/vctk/TTS/vits/posterior_encoder.py +++ b/egs/vctk/TTS/vits/posterior_encoder.py @@ -12,9 +12,9 @@ This code is based on https://github.com/jaywalnut310/vits. 
from typing import Optional, Tuple import torch -from wavenet import Conv1d, WaveNet from icefall.utils import make_pad_mask +from wavenet import WaveNet, Conv1d class PosteriorEncoder(torch.nn.Module): diff --git a/egs/vctk/TTS/vits/residual_coupling.py b/egs/vctk/TTS/vits/residual_coupling.py index f9a2a3786..2d6807cb7 100644 --- a/egs/vctk/TTS/vits/residual_coupling.py +++ b/egs/vctk/TTS/vits/residual_coupling.py @@ -12,6 +12,7 @@ This code is based on https://github.com/jaywalnut310/vits. from typing import Optional, Tuple, Union import torch + from flow import FlipFlow from wavenet import WaveNet diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py new file mode 100755 index 000000000..8acca7c02 --- /dev/null +++ b/egs/vctk/TTS/vits/test_onnx.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script is used to test the exported onnx model by vits/export-onnx.py + +Use the onnx model to generate a wav: +./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt +""" + + +import argparse +import logging +import onnxruntime as ort +import torch +import torchaudio + +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the onnx model.", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +class OnnxModel: + def __init__(self, model_filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + + def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: + """ + Args: + tokens: + A 1-D tensor of shape (1, T) + Returns: + A tensor of shape (1, T') + """ + noise_scale = torch.tensor([0.667], dtype=torch.float32) + noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) + alpha = torch.tensor([1.0], dtype=torch.float32) + + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: tokens.numpy(), + self.model.get_inputs()[1].name: tokens_lens.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), + self.model.get_inputs()[3].name: noise_scale_dur.numpy(), + self.model.get_inputs()[4].name: alpha.numpy(), + }, + )[0] + return torch.from_numpy(out) + + +def main(): + args = get_parser().parse_args() + + tokenizer = Tokenizer(args.tokens) + + logging.info("About to create onnx 
model") + model = OnnxModel(args.model_filename) + + text = "I went there to see the land, the people and how their system works, end quote." + tokens = tokenizer.texts_to_token_ids([text]) + tokens = torch.tensor(tokens) # (1, T) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) + audio = model(tokens, tokens_lens) # (1, T') + + torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) + logging.info("Saved to test_onnx.wav") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/vctk/TTS/vits/text_encoder.py b/egs/vctk/TTS/vits/text_encoder.py index 7176456de..fcbae7103 100644 --- a/egs/vctk/TTS/vits/text_encoder.py +++ b/egs/vctk/TTS/vits/text_encoder.py @@ -30,7 +30,7 @@ from typing import Optional, Tuple import torch from torch import Tensor, nn -from icefall.utils import make_pad_mask +from icefall.utils import is_jit_tracing, make_pad_mask class TextEncoder(torch.nn.Module): @@ -442,18 +442,30 @@ class RelPositionMultiheadAttention(nn.Module): """ (batch_size, num_heads, seq_len, n) = x.shape - assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1" + if not is_jit_tracing(): + assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1" - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, seq_len, seq_len), - (batch_stride, head_stride, time_stride - n_stride, n_stride), - storage_offset=n_stride * (seq_len - 1), - ) + if is_jit_tracing(): + rows = torch.arange(start=seq_len - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, seq_len, seq_len) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, seq_len, seq_len), + (batch_stride, head_stride, time_stride - n_stride, n_stride), + storage_offset=n_stride * (seq_len - 1), + ) def forward( self, diff --git a/egs/vctk/TTS/vits/tokenizer.py b/egs/vctk/TTS/vits/tokenizer.py index 8a61511ef..70f1240b4 100644 --- a/egs/vctk/TTS/vits/tokenizer.py +++ b/egs/vctk/TTS/vits/tokenizer.py @@ -77,3 +77,32 @@ class Tokenizer(object): token_ids_list.append(token_ids) return token_ids_list + + def tokens_to_token_ids( + self, tokens_list: List[str], intersperse_blank: bool = True + ): + """ + Args: + tokens_list: + A list of token list, each corresponding to one utterance. + intersperse_blank: + Whether to intersperse blanks in the token sequence. 
+ + Returns: + Return a list of token id list [utterance][token_id] + """ + token_ids_list = [] + + for tokens in tokens_list: + token_ids = [] + for t in tokens: + if t in self.token2id: + token_ids.append(self.token2id[t]) + else: + token_ids.append(self.oov_id) + + if intersperse_blank: + token_ids = intersperse(token_ids, self.blank_id) + token_ids_list.append(token_ids) + + return token_ids_list diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py index 0c6ca1b4d..1dfe92685 100755 --- a/egs/vctk/TTS/vits/train.py +++ b/egs/vctk/TTS/vits/train.py @@ -34,7 +34,7 @@ from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import LJSpeechTtsDataModule +from tts_datamodule import VctkTtsDataModule from utils import MetricsTracker, plot_feature, save_checkpoint from vits import VITS @@ -294,10 +294,9 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): features = batch["features"].to(device) audio_lens = batch["audio_lens"].to(device) features_lens = batch["features_lens"].to(device) - text = batch["text"] - speakers = batch["speakers"] + tokens = batch["tokens"] - tokens = tokenizer.texts_to_token_ids(text) + tokens = tokenizer.tokens_to_token_ids(tokens) tokens = k2.RaggedTensor(tokens) row_splits = tokens.shape.row_splits(1) tokens_lens = row_splits[1:] - row_splits[:-1] @@ -306,7 +305,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): # a tensor of shape (B, T) tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) - return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers + return audio, audio_lens, features, features_lens, tokens, tokens_lens def train_one_epoch( @@ -384,16 +383,10 @@ def train_one_epoch( for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 - batch_size = len(batch["text"]) - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - speakers, - ) = prepare_input(batch, tokenizer, device) + batch_size = len(batch["tokens"]) + audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( + batch, tokenizer, device + ) loss_info = MetricsTracker() loss_info["samples"] = batch_size @@ -582,7 +575,7 @@ def compute_validation_loss( with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): - batch_size = len(batch["text"]) + batch_size = len(batch["tokens"]) ( audio, audio_lens, @@ -810,7 +803,7 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - ljspeech = LJSpeechTtsDataModule(args) + ljspeech = VctkTtsDataModule(args) train_cuts = ljspeech.train_cuts() @@ -914,7 +907,7 @@ def run(rank, world_size, args): def main(): parser = get_parser() - LJSpeechTtsDataModule.add_arguments(parser) + VctkTtsDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) diff --git a/egs/vctk/TTS/vits/tts_datamodule.py b/egs/vctk/TTS/vits/tts_datamodule.py index d2064c5e3..93f39e329 100644 --- a/egs/vctk/TTS/vits/tts_datamodule.py +++ b/egs/vctk/TTS/vits/tts_datamodule.py @@ -52,7 +52,7 @@ class _SeedWorkers: fix_random_seed(self.seed + worker_id) -class LJSpeechTtsDataModule: +class VctkTtsDataModule: """ DataModule for tts experiments. 
It assumes there is always one train and valid dataloader, @@ -168,7 +168,8 @@ class LJSpeechTtsDataModule: """ logging.info("About to create train dataset") train = SpeechSynthesisDataset( - return_tokens=False, + return_text=False, + return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -182,7 +183,8 @@ class LJSpeechTtsDataModule: use_fft_mag=True, ) train = SpeechSynthesisDataset( - return_tokens=False, + return_text=False, + return_tokens=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), return_cuts=self.args.return_cuts, ) @@ -236,13 +238,15 @@ class LJSpeechTtsDataModule: use_fft_mag=True, ) validate = SpeechSynthesisDataset( - return_tokens=False, + return_text=False, + return_tokens=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), return_cuts=self.args.return_cuts, ) else: validate = SpeechSynthesisDataset( - return_tokens=False, + return_text=False, + return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -273,13 +277,15 @@ class LJSpeechTtsDataModule: use_fft_mag=True, ) test = SpeechSynthesisDataset( - return_tokens=False, + return_text=False, + return_tokens=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), return_cuts=self.args.return_cuts, ) else: test = SpeechSynthesisDataset( - return_tokens=False, + return_text=False, + return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) diff --git a/egs/vctk/TTS/vits/utils.py b/egs/vctk/TTS/vits/utils.py index 6a067f596..12b2d6b81 100644 --- a/egs/vctk/TTS/vits/utils.py +++ b/egs/vctk/TTS/vits/utils.py @@ -14,15 +14,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, List, Optional, Tuple, Union import collections import logging -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union import torch -import torch.distributed as dist import torch.nn as nn +import torch.distributed as dist from lhotse.dataset.sampling.base import CutSampler +from pathlib import Path from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer diff --git a/egs/vctk/TTS/vits/vits.py b/egs/vctk/TTS/vits/vits.py index 2c38d5d37..6db1cdee1 100644 --- a/egs/vctk/TTS/vits/vits.py +++ b/egs/vctk/TTS/vits/vits.py @@ -9,7 +9,8 @@ from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn -from generator import VITSGenerator +from torch.cuda.amp import autocast + from hifigan import ( HiFiGANMultiPeriodDiscriminator, HiFiGANMultiScaleDiscriminator, @@ -24,8 +25,9 @@ from loss import ( KLDivergenceLoss, MelSpectrogramLoss, ) -from torch.cuda.amp import autocast from utils import get_segments +from generator import VITSGenerator + AVAILABLE_GENERATERS = { "vits_generator": VITSGenerator, @@ -570,6 +572,7 @@ class VITS(nn.Module): self, text: torch.Tensor, text_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, durations: Optional[torch.Tensor] = None, noise_scale: float = 0.667, noise_scale_dur: float = 0.8, @@ -582,6 +585,7 @@ class VITS(nn.Module): Args: text (Tensor): Input text index tensor (B, T_text). text_lengths (Tensor): Input text index tensor (B,). + sids (Tensor): Speaker index tensor (B,). 
noise_scale (float): Noise scale value for flow. noise_scale_dur (float): Noise scale value for duration predictor. alpha (float): Alpha parameter to control the speed of generated speech. @@ -596,6 +600,7 @@ class VITS(nn.Module): wav, att_w, dur = self.generator.inference( text=text, text_lengths=text_lengths, + sids=sids, noise_scale=noise_scale, noise_scale_dur=noise_scale_dur, alpha=alpha, diff --git a/egs/vctk/TTS/vits/wavenet.py b/egs/vctk/TTS/vits/wavenet.py index 5db461d5c..fbe1be52b 100644 --- a/egs/vctk/TTS/vits/wavenet.py +++ b/egs/vctk/TTS/vits/wavenet.py @@ -9,8 +9,9 @@ This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. """ -import logging import math +import logging + from typing import Optional, Tuple import torch
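
Note on the new token pipeline (editorial addition, not part of the patch): after this change, cuts carry precomputed phoneme tokens (attached by local/prepare_tokens_vctk.py) and the new Tokenizer.tokens_to_token_ids() maps them to IDs, instead of running g2p on raw text at training time. Below is a minimal sketch of that flow, assuming the prepare.sh stages have completed and that it is run from egs/vctk/TTS/ so the data/ paths from prepare.sh resolve; the paths and the sys.path tweak are assumptions for illustration, not part of the recipe.

#!/usr/bin/env python3
# Minimal sketch (assumption: run from egs/vctk/TTS/ after prepare.sh has finished).
import sys

from lhotse import load_manifest

sys.path.insert(0, "vits")  # so this recipe's tokenizer.py is importable
from tokenizer import Tokenizer

# Cuts produced by the split in prepare.sh (stage 4); path assumed from prepare.sh.
cut_set = load_manifest("data/spectrogram/vctk_cuts_train.jsonl.gz")

# data/tokens.txt is generated by local/prepare_token_file.py (stage 5).
tokenizer = Tokenizer("data/tokens.txt")

cut = next(iter(cut_set))
# cut.tokens was filled by local/prepare_tokens_vctk.py (g2p_en phonemes).
print(cut.tokens[:10])

# tokens_to_token_ids() maps phonemes to IDs and, by default
# (intersperse_blank=True), intersperses the blank ID between tokens,
# which is the form the VITS text encoder consumes in train.py/infer.py.
token_ids = tokenizer.tokens_to_token_ids([cut.tokens])
print(token_ids[0][:10])

The same tokens_to_token_ids() call is what prepare_input() in vits/train.py and infer_dataset() in vits/infer.py now use on batch["tokens"], replacing the previous texts_to_token_ids() path that re-ran text normalization and g2p on every batch.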