diff --git a/egs/ljspeech/tts/local/prepare_token_file.py b/egs/ljspeech/tts/local/prepare_token_file.py
new file mode 100755
index 000000000..17a558899
--- /dev/null
+++ b/egs/ljspeech/tts/local/prepare_token_file.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+# Copyright    2023  Xiaomi Corp.  (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and generates the file that maps tokens to IDs.
+"""
+
+import argparse
+import logging
+from collections import Counter
+from pathlib import Path
+from typing import Dict
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from lhotse import load_manifest
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--manifest-file",
+        type=Path,
+        default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
+        help="Path to the manifest file",
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=Path,
+        default=Path("data/tokens.txt"),
+        help="Path to the tokens",
+    )
+
+    return parser.parse_args()
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+    """Write a symbol-to-ID mapping to a file.
+
+    Note:
+      No need to implement `read_mapping` as it can be done
+      through :func:`k2.SymbolTable.from_file`.
+
+    Args:
+      filename:
+        Filename to save the mapping.
+      sym2id:
+        A dict mapping symbols to IDs.
+    Returns:
+      Return None.
+    """
+    with open(filename, "w", encoding="utf-8") as f:
+        for sym, i in sym2id.items():
+            f.write(f"{sym} {i}\n")
+
+
+def get_token2id(manifest_file: Path) -> Dict[str, int]:
+    """Return a dict that maps tokens to IDs."""
+    extra_tokens = {
+        "<blk>": 0,  # blank
+        "<sos/eos>": 1,  # sos and eos symbols.
+ "": 2, # OOV + } + cut_set = load_manifest(manifest_file) + g2p = g2p_en.G2p() + counter = Counter() + + for cut in cut_set: + # Each cut only contain one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + text = cut.supervisions[0].normalized_text + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + tokens = g2p(text) + for t in tokens: + counter[t] += 1 + + # Sort by the number of occurrences in descending order + tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1]) + + for token, idx in extra_tokens.items(): + tokens_and_counts.insert(idx, (token, None)) + + token2id: Dict[str, int] = {token: i for i, (token, count) in enumerate(tokens_and_counts)} + return token2id + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + manifest_file = Path(args.manifest_file) + out_file = Path(args.tokens) + + token2id = get_token2id(manifest_file) + write_mapping(out_file, token2id) diff --git a/egs/ljspeech/tts/local/split_subsets.py b/egs/ljspeech/tts/local/split_subsets.py index 328cdd691..b2afca971 100755 --- a/egs/ljspeech/tts/local/split_subsets.py +++ b/egs/ljspeech/tts/local/split_subsets.py @@ -52,7 +52,8 @@ def main(): manifest_dir = Path(args.manifest_dir) prefix = "ljspeech" suffix = "jsonl.gz" - all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}") + # all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}") + all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all_phonemized.{suffix}") cut_ids = list(all_cuts.ids) random.shuffle(cut_ids) diff --git a/egs/ljspeech/tts/prepare.sh b/egs/ljspeech/tts/prepare.sh index f78964c34..4f4685951 100755 --- a/egs/ljspeech/tts/prepare.sh +++ b/egs/ljspeech/tts/prepare.sh @@ -66,11 +66,50 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi fi +# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then +# log "Stage 3: Phonemize the transcripts for LJSpeech" +# if [ ! -e data/spectrogram/.ljspeech_phonemized.done ]; then +# ./local/phonemize_text.py data/spectrogram +# touch data/spectrogram/.ljspeech_phonemized.done +# fi +# fi + +# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then +# log "Stage 4: Split the LJSpeech cuts into three sets" +# if [ ! -e data/spectrogram/.ljspeech_split.done ]; then +# ./local/split_subsets.py data/spectrogram +# touch data/spectrogram/.ljspeech_split.done +# fi +# fi + if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Split the LJSpeech cuts into three sets" + log "Stage 3: Split the LJSpeech cuts into train, valid and test sets" if [ ! 
-e data/spectrogram/.ljspeech_split.done ]; then - ./local/split_subsets.py data/spectrogram - touch data/spectrogram/.ljspeech_split.done + lhotse subset --last 600 \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ + data/spectrogram/ljspeech_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ + data/spectrogram/ljspeech_cuts_test.jsonl.gz + rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 )) + lhotse subset --first $n \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_train.jsonl.gz + touch data/spectrogram/.ljspeech_split.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Generate token file" + if [ ! -e data/tokens.txt ]; then + ./local/prepare_token_file.py \ + --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \ + --tokens data/tokens.txt fi fi diff --git a/egs/ljspeech/tts/vits/generator.py b/egs/ljspeech/tts/vits/generator.py index dbf503944..a74440c95 100644 --- a/egs/ljspeech/tts/vits/generator.py +++ b/egs/ljspeech/tts/vits/generator.py @@ -515,10 +515,12 @@ class VITSGenerator(torch.nn.Module): cum_dur_flat = cum_dur.view(b * t_x) path = torch.arange(t_y, dtype=dur.dtype, device=dur.device) path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1) - path = path.view(b, t_x, t_y).to(dtype=mask.dtype) + # path = path.view(b, t_x, t_y).to(dtype=mask.dtype) + path = path.view(b, t_x, t_y).to(dtype=torch.float) # path will be like (t_x = 3, t_y = 5): # [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.], # [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.], # [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]] path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1] + # path = path.to(dtype=mask.dtype) return path.unsqueeze(1).transpose(2, 3) * mask diff --git a/egs/ljspeech/tts/vits/infer.py b/egs/ljspeech/tts/vits/infer.py new file mode 100755 index 000000000..89fc72962 --- /dev/null +++ b/egs/ljspeech/tts/vits/infer.py @@ -0,0 +1,366 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
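(Regarding the alignment-path change in generator.py above: the boolean comparison is now cast to torch.float explicitly. A minimal, self-contained sketch of that computation, using the same durations as the in-code comment, so the intermediate cumulative mask and the final one-hot path can be inspected directly; the duration values are illustrative only.)

    import torch
    import torch.nn.functional as F

    # One utterance, t_x = 3 tokens with durations 2, 2, 1 -> t_y = 5 frames,
    # matching the example in the comment above.
    dur = torch.tensor([[2.0, 2.0, 1.0]])
    b, t_x = dur.shape
    t_y = int(dur.sum().item())

    cum_dur = torch.cumsum(dur, -1)                       # [[2., 4., 5.]]
    cum_dur_flat = cum_dur.view(b * t_x)
    path = torch.arange(t_y, dtype=dur.dtype)
    path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)  # (b*t_x, t_y), bool
    path = path.view(b, t_x, t_y).to(dtype=torch.float)
    # cumulative mask:          after subtracting the shifted copy:
    # [[1., 1., 0., 0., 0.],    [[1., 1., 0., 0., 0.],
    #  [1., 1., 1., 1., 0.],     [0., 0., 1., 1., 0.],
    #  [1., 1., 1., 1., 1.]]     [0., 0., 0., 0., 1.]]
    path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
    print(path)
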
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torchaudio + +from train2 import get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) +from tts_datamodule import LJSpeechTtsDataModule +from utils import prepare_token_batch + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + return parser + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. 
+ params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + # Background worker save audios to disk. + def _save_worker( + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + audio_lens_pred: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"), + audio[i:i + 1, :audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"), + audio_pred[i:i + 1, :audio_lens_pred[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 10 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + # We only want one background worker so that serialization is deterministic. + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["text"]) + text = batch["text"] + tokens, tokens_lens = prepare_token_batch(text) + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens) + audio_pred = audio_pred.detach().cpu() + # convert to samples + audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + + # import pdb + # pdb.set_trace() + + futures.append( + executor.submit( + _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred + ) + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + # return results + for f in futures: + f.result() + + +@torch.no_grad() +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found 
for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + ljspeech = LJSpeechTtsDataModule(args) + + test_cuts = ljspeech.test_cuts() + test_dl = ljspeech.test_dataloaders(test_cuts) + + infer_dataset( + dl=test_dl, + params=params, + model=model, + ) + + # save_results( + # params=params, + # test_set_name=test_set, + # results_dict=results_dict, + # ) + + logging.info("Done!") + + +# torch.set_num_threads(1) +# torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/tts/vits/loss.py b/egs/ljspeech/tts/vits/loss.py index d322f5e05..0d27af643 100644 --- a/egs/ljspeech/tts/vits/loss.py +++ b/egs/ljspeech/tts/vits/loss.py @@ -241,7 +241,8 @@ class MelSpectrogramLoss(torch.nn.Module): self, y_hat: torch.Tensor, y: torch.Tensor, - ) -> torch.Tensor: + return_mel: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: """Calculate Mel-spectrogram loss. Args: @@ -259,6 +260,9 @@ class MelSpectrogramLoss(torch.nn.Module): mel = self.wav_to_mel(y.squeeze(1)) mel_loss = F.l1_loss(mel_hat, mel) + if return_mel: + return mel_loss, (mel_hat, mel) + return mel_loss diff --git a/egs/ljspeech/tts/vits/tokenizer.py b/egs/ljspeech/tts/vits/tokenizer.py new file mode 100644 index 000000000..5a513a0d9 --- /dev/null +++ b/egs/ljspeech/tts/vits/tokenizer.py @@ -0,0 +1,80 @@ +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from typing import Dict, List
+
+import g2p_en
+import tacotron_cleaner.cleaners
+
+from utils import intersperse
+
+
+class Tokenizer(object):
+    def __init__(self, tokens: str):
+        """
+        Args:
+          tokens: the file that maps tokens to ids
+        """
+        # Parse token file
+        self.token2id: Dict[str, int] = {}
+        with open(tokens, "r", encoding="utf-8") as f:
+            for line in f.readlines():
+                info = line.rstrip().split()
+                if len(info) == 1:
+                    # case of space
+                    token = " "
+                    id = int(info[0])
+                else:
+                    token, id = info[0], int(info[1])
+                self.token2id[token] = id
+
+        self.blank_id = self.token2id["<blk>"]
+        self.oov_id = self.token2id["<unk>"]
+        self.vocab_size = len(self.token2id)
+
+        self.g2p = g2p_en.G2p()
+
+    def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
+        """
+        Args:
+          texts:
+            A list of transcripts.
+          intersperse_blank:
+            Whether to intersperse blanks in the token sequence.
+
+        Returns:
+          Return a list of token id lists [utterance][token_id].
+        """
+        token_ids_list = []
+
+        for text in texts:
+            # Text normalization
+            text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+            # Convert to phonemes
+            tokens = self.g2p(text)
+            token_ids = []
+            for t in tokens:
+                if t in self.token2id:
+                    token_ids.append(self.token2id[t])
+                else:
+                    token_ids.append(self.oov_id)
+
+            if intersperse_blank:
+                token_ids = intersperse(token_ids, self.blank_id)
+
+            token_ids_list.append(token_ids)
+
+        return token_ids_list
diff --git a/egs/ljspeech/tts/vits/train.py b/egs/ljspeech/tts/vits/train.py
index 8fd2a596a..01cd6137e 100755
--- a/egs/ljspeech/tts/vits/train.py
+++ b/egs/ljspeech/tts/vits/train.py
@@ -1,10 +1,32 @@
 #!/usr/bin/env python3
+# Copyright    2021-2023  Xiaomi Corp.  (authors: Fangjun Kuang,
+#                                                 Wei Kang,
+#                                                 Mingshuang Luo,
+#                                                 Zengwei Yao,
+#                                                 Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
 import argparse
 import logging
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Union
 
+import k2
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -27,10 +49,10 @@ from icefall.utils import (
     str2bool,
 )
 
-from symbols import symbol_table
+from tokenizer import Tokenizer
 from utils import (
     MetricsTracker,
-    prepare_token_batch,
+    plot_feature,
     save_checkpoint,
     save_checkpoint_with_global_batch_idx,
 )
@@ -101,6 +123,13 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="data/tokens.txt",
+        help="""Path to tokens.txt.""",
+    )
+
     parser.add_argument(
         "--lr", type=float, default=2.0e-4, help="The base learning rate."
) @@ -213,16 +242,16 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": -1, # 0 - "log_interval": 50, + "log_interval": 10, + "draw_interval": 500, # "reset_interval": 200, - "valid_interval": 500, + "valid_interval": 200, "env_info": get_env_info(), "sampling_rate": 22050, + "frame_shift": 256, + "frame_length": 1024, "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length - "vocab_size": len(symbol_table), "mel_loss_params": { - "frame_shift": 256, - "frame_length": 1024, "n_mels": 80, }, "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss @@ -287,11 +316,16 @@ def load_checkpoint_if_available( def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = params.mel_loss_params + mel_loss_params.update( + frame_length=params.frame_length, + frame_shift=params.frame_shift, + ) model = VITS( vocab_size=params.vocab_size, feature_dim=params.feature_dim, sampling_rate=params.sampling_rate, - mel_loss_params=params.mel_loss_params, + mel_loss_params=mel_loss_params, lambda_adv=params.lambda_adv, lambda_mel=params.lambda_mel, lambda_feat_match=params.lambda_feat_match, @@ -301,79 +335,30 @@ def get_model(params: AttributeDict) -> nn.Module: return model -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + text = batch["text"] - # used to summary the stats over iterations - tot_loss = MetricsTracker() + tokens = tokenizer.texts_to_token_ids(text) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) - with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): - batch_size = len(batch["text"]) - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - text = batch["text"] - tokens, tokens_lens = prepare_token_batch(text) - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) - - loss_info = MetricsTracker() - loss_info['samples'] = batch_size - - # forward discriminator - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=False, - ) - assert loss_d.requires_grad is False - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - - # forward generator - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - ) - assert loss_g.requires_grad is False - for k, v in stats_g.items(): - loss_info[k] = v * batch_size - - # summary stats - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(device) - - loss_value = 
tot_loss["generator_loss"] / tot_loss["samples"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss + return audio, audio_lens, features, features_lens, tokens, tokens_lens def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], + tokenizer: Tokenizer, optimizer_g: Optimizer, optimizer_d: Optimizer, scheduler_g: LRSchedulerType, @@ -442,18 +427,13 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["text"]) - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - text = batch["text"] - tokens, tokens_lens = prepare_token_batch(text) - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) loss_info = MetricsTracker() loss_info['samples'] = batch_size + return_sample = params.batch_idx_train % params.log_interval == 0 try: with autocast(enabled=params.use_fp16): # forward discriminator @@ -483,9 +463,13 @@ def train_one_epoch( speech=audio, speech_lengths=audio_lens, forward_generator=True, + return_sample=return_sample, ) for k, v in stats_g.items(): - loss_info[k] = v * batch_size + if "return_sample" not in k: + loss_info[k] = v * batch_size + if return_sample: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["return_sample"] # update generator optimizer_g.zero_grad() scaler.scale(loss_g).backward() @@ -577,13 +561,27 @@ def train_one_epoch( tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train ) + if return_sample: + tb_writer.add_audio( + "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate + ) + tb_writer.add_audio( + "train/speech_", speech_, params.batch_idx_train, params.sampling_rate + ) + tb_writer.add_image( + "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC' + ) + tb_writer.add_image( + "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC' + ) # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") - valid_info = compute_validation_loss( + valid_info, (speech_hat, speech) = compute_validation_loss( params=params, model=model, + tokenizer=tokenizer, valid_dl=valid_dl, world_size=world_size, ) @@ -596,6 +594,12 @@ def train_one_epoch( valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) + tb_writer.add_audio( + "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate + ) + tb_writer.add_audio( + "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate + ) loss_value = tot_loss["generator_loss"] / tot_loss["samples"] params.train_loss = loss_value @@ -604,9 +608,87 @@ def train_one_epoch( params.best_train_loss = params.train_loss +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + return_sample = None + + with 
torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["text"]) + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) + + loss_info = MetricsTracker() + loss_info['samples'] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + audio_pred, _, duration = inner_model.inference(text=tokens[0, :tokens_lens[0].item()]) + audio_pred = audio_pred.data.cpu().numpy() + audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred)) + audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy() + return_sample = (audio_pred, audio_gt) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, return_sample + + def scan_pessimistic_batches_for_oom( model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, + tokenizer: Tokenizer, optimizer_g: torch.optim.Optimizer, optimizer_d: torch.optim.Optimizer, params: AttributeDict, @@ -620,14 +702,8 @@ def scan_pessimistic_batches_for_oom( batches, crit_values = find_pessimistic_batches(train_dl.sampler) for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - text = batch["text"] - tokens, tokens_lens = prepare_token_batch(text) - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) try: # for discriminator with autocast(enabled=params.use_fp16): @@ -702,6 +778,11 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + logging.info(params) logging.info("About to create model") @@ -728,14 +809,14 @@ def run(rank, world_size, args): lr=params.lr, betas=(0.8, 0.99), eps=1e-9, - weight_decay=0, + # weight_decay=0, ) optimizer_d = torch.optim.AdamW( discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9, - weight_decay=0, + # weight_decay=0, ) scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) @@ -804,6 +885,7 @@ def run(rank, world_size, args): scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, + tokenizer=tokenizer, 
optimizer_g=optimizer_g, optimizer_d=optimizer_d, params=params, @@ -815,6 +897,8 @@ def run(rank, world_size, args): scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) @@ -826,6 +910,7 @@ def run(rank, world_size, args): train_one_epoch( params=params, model=model, + tokenizer=tokenizer, optimizer_g=optimizer_g, optimizer_d=optimizer_d, scheduler_g=scheduler_g, diff --git a/egs/ljspeech/tts/vits/tts_datamodule.py b/egs/ljspeech/tts/vits/tts_datamodule.py index bd67aa6b1..40e9c19dd 100644 --- a/egs/ljspeech/tts/vits/tts_datamodule.py +++ b/egs/ljspeech/tts/vits/tts_datamodule.py @@ -131,7 +131,14 @@ class LJSpeechTtsDataModule: default=True, help="Whether to drop last batch. Used by sampler.", ) - + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) group.add_argument( "--num-workers", type=int, @@ -163,6 +170,7 @@ class LJSpeechTtsDataModule: train = SpeechSynthesisDataset( return_tokens=False, feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, ) if self.args.on_the_fly_feats: @@ -176,6 +184,7 @@ class LJSpeechTtsDataModule: train = SpeechSynthesisDataset( return_tokens=False, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, ) if self.args.bucketing_sampler: @@ -229,11 +238,13 @@ class LJSpeechTtsDataModule: validate = SpeechSynthesisDataset( return_tokens=False, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, ) else: validate = SpeechSynthesisDataset( return_tokens=False, feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, ) valid_sampler = DynamicBucketingSampler( cuts_valid, @@ -264,11 +275,13 @@ class LJSpeechTtsDataModule: test = SpeechSynthesisDataset( return_tokens=False, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, ) else: test = SpeechSynthesisDataset( return_tokens=False, feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, ) test_sampler = DynamicBucketingSampler( cuts, diff --git a/egs/ljspeech/tts/vits/utils.py b/egs/ljspeech/tts/vits/utils.py index 002097581..582856eee 100644 --- a/egs/ljspeech/tts/vits/utils.py +++ b/egs/ljspeech/tts/vits/utils.py @@ -211,6 +211,7 @@ def intersperse(sequence, item=0): def prepare_token_batch( texts: List[str], + phonemes: Optional[List[str]] = None, intersperse_blank: bool = True, blank_id: int = 0, pad_id: int = 0, @@ -222,41 +223,50 @@ def prepare_token_batch( blank_id: index of blank token pad_id: padding index """ - # normalize text - normalized_texts = [] - for text in texts: - text = convert_to_ascii(text) - text = lowercase(text) - text = expand_abbreviations(text) - normalized_texts.append(text) + if phonemes is None: + # normalize text + normalized_texts = [] + for text in texts: + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + normalized_texts.append(text) - # convert to phonemes - phonemes = phonemize( - normalized_texts, - language='en-us', - backend='espeak', - strip=True, - preserve_punctuation=True, - with_stress=True, - ) + # convert to phonemes + phonemes = phonemize( + 
normalized_texts, + language='en-us', + backend='espeak', + strip=True, + preserve_punctuation=True, + with_stress=True, + ) + phonemes = [collapse_whitespace(sequence) for sequence in phonemes] # convert to symbol ids lengths = [] sequences = [] + skip = False for idx, sequence in enumerate(phonemes): try: - sequence = [symbol_to_id[symbol] for symbol in collapse_whitespace(sequence)] - except RuntimeError: - print(text[idx]) - print(normalized_texts[idx]) + sequence = [symbol_to_id[symbol] for symbol in sequence] + except Exception: + # print(texts[idx]) + # print(normalized_texts[idx]) + print(phonemes[idx]) + skip = True if intersperse_blank: sequence = intersperse(sequence, blank_id) - sequences.append(torch.tensor(sequence, dtype=torch.int64)) + try: + sequences.append(torch.tensor(sequence, dtype=torch.int64)) + except Exception: + print(sequence) + skip = True lengths.append(len(sequence)) sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id) lengths = torch.tensor(lengths, dtype=torch.int64) - return sequences, lengths + return sequences, lengths, skip class MetricsTracker(collections.defaultdict): @@ -287,7 +297,7 @@ class MetricsTracker(collections.defaultdict): norm_value = "%.4g" % v ans += str(k) + "=" + str(norm_value) + ", " samples = "%.2f" % self["samples"] - ans += "over" + str(samples) + " samples." + ans += "over " + str(samples) + " samples." return ans def norm_items(self) -> List[Tuple[str, float]]: @@ -468,3 +478,41 @@ def save_checkpoint_with_global_batch_idx( sampler=sampler, rank=rank, ) + + +# def plot_feature(feature): +# """ +# Display the feature matrix as an image. Requires matplotlib to be installed. +# """ +# import matplotlib.pyplot as plt +# +# feature = np.flip(feature.transpose(1, 0), 0) +# return plt.matshow(feature) + +MATPLOTLIB_FLAG = False + + +def plot_feature(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data diff --git a/egs/ljspeech/tts/vits/vits.py b/egs/ljspeech/tts/vits/vits.py index da9d144f2..441e915df 100644 --- a/egs/ljspeech/tts/vits/vits.py +++ b/egs/ljspeech/tts/vits/vits.py @@ -241,6 +241,7 @@ class VITS(nn.Module): feats_lengths: torch.Tensor, speech: torch.Tensor, speech_lengths: torch.Tensor, + return_sample: bool = False, sids: Optional[torch.Tensor] = None, spembs: Optional[torch.Tensor] = None, lids: Optional[torch.Tensor] = None, @@ -276,6 +277,7 @@ class VITS(nn.Module): feats_lengths=feats_lengths, speech=speech, speech_lengths=speech_lengths, + return_sample=return_sample, sids=sids, spembs=spembs, lids=lids, @@ -301,6 +303,7 @@ class VITS(nn.Module): feats_lengths: torch.Tensor, speech: torch.Tensor, speech_lengths: torch.Tensor, + return_sample: bool = False, sids: Optional[torch.Tensor] = None, spembs: Optional[torch.Tensor] = None, lids: Optional[torch.Tensor] = None, @@ -367,7 +370,12 @@ class VITS(nn.Module): # calculate losses with autocast(enabled=False): - mel_loss = 
self.mel_loss(speech_hat_, speech_) + if not return_sample: + mel_loss = self.mel_loss(speech_hat_, speech_) + else: + mel_loss, (mel_hat_, mel_) = self.mel_loss( + speech_hat_, speech_, return_mel=True + ) kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask) dur_loss = torch.sum(dur_nll.float()) adv_loss = self.generator_adv_loss(p_hat) @@ -389,6 +397,14 @@ class VITS(nn.Module): generator_feat_match_loss=feat_match_loss.item(), ) + if return_sample: + stats["return_sample"] = ( + speech_hat_[0].data.cpu().numpy(), + speech_[0].data.cpu().numpy(), + mel_hat_[0].data.cpu().numpy(), + mel_[0].data.cpu().numpy(), + ) + # reset cache if reuse_cache or not self.training: self._cache = None @@ -564,4 +580,43 @@ class VITS(nn.Module): alpha=alpha, max_len=max_len, ) - return dict(wav=wav.view(-1), att_w=att_w[0], duration=dur[0]) + return wav.view(-1), att_w[0], dur[0] + + def inference_batch( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + durations: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Dict[str, torch.Tensor]: + """Run inference. + + Args: + text (Tensor): Input text index tensor (B, T_text). + text_lengths (Tensor): Input text index tensor (B,). + noise_scale (float): Noise scale value for flow. + noise_scale_dur (float): Noise scale value for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length. + + Returns: + Dict[str, Tensor]: + * wav (Tensor): Generated waveform tensor (B, T_wav). + * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text). + * duration (Tensor): Predicted duration tensor (B, T_text). + + """ + # inference + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + max_len=max_len, + ) + return wav, att_w, dur
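
(A minimal usage sketch tying the new `inference_batch` API to the `Tokenizer` and the `frame_shift`/`sampling_rate` values from `get_params`. The `synthesize` helper is hypothetical, and `model` is assumed to be a VITS instance built with `get_model(params)` and loaded from a checkpoint, as `infer.py` does.)

    import torch
    import torchaudio

    from tokenizer import Tokenizer


    def synthesize(model: torch.nn.Module, tokenizer: Tokenizer, text: str, out_wav: str):
        """Synthesize one sentence with a trained VITS model."""
        token_ids = tokenizer.texts_to_token_ids([text])
        tokens = torch.tensor(token_ids, dtype=torch.int64)               # (1, T_text)
        tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1,)
        # (move tokens and tokens_lens to the model's device if running on GPU)

        model.eval()
        with torch.no_grad():
            audio, _, durations = model.inference_batch(
                text=tokens, text_lengths=tokens_lens
            )

        # durations is counted in feature frames; with frame_shift=256 samples
        # per frame (see get_params), the waveform length in samples is:
        num_samples = int(durations.sum(1).item()) * 256
        torchaudio.save(out_wav, audio[:, :num_samples].cpu(), sample_rate=22050)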