From e4d40baaf57c4c9a5e8cb79da0899bcd2c1340d1 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 26 Dec 2024 11:51:56 +0800
Subject: [PATCH] ready to train

---
 egs/baker_zh/TTS/local/audio.py               |   1 +
 .../TTS/local/compute_fbank_baker_zh.py       | 110 +++
 .../TTS/local/compute_fbank_statistics.py     |  84 ++
 .../TTS/local/convert_text_to_tokens.py       | 119 +++
 egs/baker_zh/TTS/local/fbank.py               |   1 +
 egs/baker_zh/TTS/local/generate_tokens.py     |   6 +-
 egs/baker_zh/TTS/local/validate_manifest.py   |  70 ++
 egs/baker_zh/TTS/matcha/tokenizer.py          | 120 ++-
 egs/baker_zh/TTS/matcha/train.py              | 717 ++++++++++++++++++
 egs/baker_zh/TTS/matcha/tts_datamodule.py     | 340 +++++++++
 egs/baker_zh/TTS/prepare.sh                   |  67 ++
 11 files changed, 1633 insertions(+), 2 deletions(-)
 create mode 120000 egs/baker_zh/TTS/local/audio.py
 create mode 100755 egs/baker_zh/TTS/local/compute_fbank_baker_zh.py
 create mode 100755 egs/baker_zh/TTS/local/compute_fbank_statistics.py
 create mode 100755 egs/baker_zh/TTS/local/convert_text_to_tokens.py
 create mode 120000 egs/baker_zh/TTS/local/fbank.py
 mode change 100644 => 100755 egs/baker_zh/TTS/local/generate_tokens.py
 create mode 100755 egs/baker_zh/TTS/local/validate_manifest.py
 mode change 120000 => 100644 egs/baker_zh/TTS/matcha/tokenizer.py
 create mode 100755 egs/baker_zh/TTS/matcha/train.py
 create mode 100644 egs/baker_zh/TTS/matcha/tts_datamodule.py

diff --git a/egs/baker_zh/TTS/local/audio.py b/egs/baker_zh/TTS/local/audio.py
new file mode 120000
index 000000000..b70d91c92
--- /dev/null
+++ b/egs/baker_zh/TTS/local/audio.py
@@ -0,0 +1 @@
+../matcha/audio.py
\ No newline at end of file
diff --git a/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py b/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py
new file mode 100755
index 000000000..0720158f2
--- /dev/null
+++ b/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+# Copyright    2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the baker-zh dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+from fbank import MatchaFbank, MatchaFbankConfig
+from lhotse import CutSet, LilcomChunkyWriter, load_manifest
+from lhotse.audio import RecordingSet
+from lhotse.supervision import SupervisionSet
+
+from icefall.utils import get_executor
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--num-jobs",
+        type=int,
+        default=4,
+        help="""Number of parallel jobs to use for computing the features.
+        If it is less than 1, the number of available CPU cores is used.
+        """,
+    )
+    return parser
+
+
+def compute_fbank_baker_zh(num_jobs: int):
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+
+    if num_jobs < 1:
+        num_jobs = os.cpu_count()
+
+    logging.info(f"num_jobs: {num_jobs}")
+    logging.info(f"src_dir: {src_dir}")
+    logging.info(f"output_dir: {output_dir}")
+    config = MatchaFbankConfig(
+        n_fft=1024,
+        n_mels=80,
+        sampling_rate=22050,
+        hop_length=256,
+        win_length=1024,
+        f_min=0,
+        f_max=8000,
+    )
+
+    prefix = "baker_zh"
+    suffix = "jsonl.gz"
+
+    extractor = MatchaFbank(config)
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        cuts_filename = f"{prefix}_cuts.{suffix}"
+        logging.info(f"Processing {cuts_filename}")
+        cut_set = load_manifest(src_dir / cuts_filename).resample(22050)
+
+        cut_set = cut_set.compute_and_store_features(
+            extractor=extractor,
+            storage_path=f"{output_dir}/{prefix}_feats",
+            num_jobs=num_jobs if ex is None else 80,
+            executor=ex,
+            storage_type=LilcomChunkyWriter,
+        )
+
+        cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+    # Torch's multithreaded behavior needs to be disabled, or
+    # it wastes a lot of CPU and slows things down.
+    # Do this outside of main() in case it needs to take effect
+    # even when we are not invoking the main (e.g. when spawning subprocesses).
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    args = get_parser().parse_args()
+    compute_fbank_baker_zh(args.num_jobs)
diff --git a/egs/baker_zh/TTS/local/compute_fbank_statistics.py b/egs/baker_zh/TTS/local/compute_fbank_statistics.py
new file mode 100755
index 000000000..d0232c983
--- /dev/null
+++ b/egs/baker_zh/TTS/local/compute_fbank_statistics.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python3
+# Copyright         2024  Xiaomi Corp.        (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script computes the mean and std of the fbank features.
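+
+Usage example (this matches how the script is invoked from ./prepare.sh,
+stage 7):
+
+  python3 ./local/compute_fbank_statistics.py \
+    ./data/fbank/baker_zh_cuts_train.jsonl.gz \
+    ./data/fbank/cmvn.json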
+""" + +import argparse +import json +import logging +from pathlib import Path + +import torch +from lhotse import CutSet, load_manifest_lazy + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + parser.add_argument( + "cmvn", + type=Path, + help="Path to the cmvn.json", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info( + f"Computing fbank mean and std for {manifest} and saving to {args.cmvn}" + ) + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet), type(cut_set) + + feat_dim = cut_set[0].features.num_features + num_frames = 0 + s = 0 + sq = 0 + for c in cut_set: + f = torch.from_numpy(c.load_features()) + num_frames += f.shape[0] + s += f.sum() + sq += f.square().sum() + + fbank_mean = s / (num_frames * feat_dim) + fbank_var = sq / (num_frames * feat_dim) - fbank_mean * fbank_mean + print("fbank var", fbank_var) + fbank_std = fbank_var.sqrt() + with open(args.cmvn, "w") as f: + json.dump({"fbank_mean": fbank_mean.item(), "fbank_std": fbank_std.item()}, f) + f.write("\n") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/baker_zh/TTS/local/convert_text_to_tokens.py b/egs/baker_zh/TTS/local/convert_text_to_tokens.py new file mode 100755 index 000000000..a20165089 --- /dev/null +++ b/egs/baker_zh/TTS/local/convert_text_to_tokens.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 + +import argparse +import re +from typing import List + +import jieba +from lhotse import load_manifest +from pypinyin import lazy_pinyin, load_phrases_dict, Style + +load_phrases_dict( + { + "行长": [["hang2"], ["zhang3"]], + "银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]], + } +) + +whiter_space_re = re.compile(r"\s+") + +punctuations_re = [ + (re.compile(x[0], re.IGNORECASE), x[1]) + for x in [ + (",", ","), + ("。", "."), + ("!", "!"), + ("?", "?"), + ("“", '"'), + ("”", '"'), + ("‘", "'"), + ("’", "'"), + (":", ":"), + ("、", ","), + ] +] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--in-file", + type=str, + required=True, + help="Input cutset.", + ) + + parser.add_argument( + "--out-file", + type=str, + required=True, + help="Output cutset.", + ) + + return parser + + +def normalize_white_spaces(text): + return whiter_space_re.sub(" ", text) + + +def normalize_punctuations(text): + for regex, replacement in punctuations_re: + text = re.sub(regex, replacement, text) + return text + + +def split_text(text: str) -> List[str]: + """ + Example input: '你好呀,You are 一个好人。 去银行存钱?How about you?' 
+    Example output: ['你好', '呀', ',', 'you are', '一个', '好人', '.', '去', '银行', '存钱', '?', 'how about you', '?']
+    """
+    text = text.lower()
+    text = normalize_white_spaces(text)
+    text = normalize_punctuations(text)
+    ans = []
+
+    for seg in jieba.cut(text):
+        if seg in ",.!?:\"'":
+            ans.append(seg)
+        elif seg == " " and len(ans) > 0:
+            if ord("a") <= ord(ans[-1][-1]) <= ord("z"):
+                ans[-1] += seg
+        elif ord("a") <= ord(seg[0]) <= ord("z"):
+            if len(ans) == 0:
+                ans.append(seg)
+                continue
+
+            if ans[-1][-1] == " ":
+                ans[-1] += seg
+                continue
+
+            ans.append(seg)
+        else:
+            ans.append(seg)
+
+    ans = [s.strip() for s in ans]
+    return ans
+
+
+def main():
+    args = get_parser().parse_args()
+    cuts = load_manifest(args.in_file)
+    for c in cuts:
+        assert len(c.supervisions) == 1, (len(c.supervisions), c.supervisions)
+        text = c.supervisions[0].normalized_text
+
+        text_list = split_text(text)
+        tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True)
+
+        c.supervisions[0].tokens = tokens
+
+    cuts.to_file(args.out_file)
+
+    print(f"saved to {args.out_file}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/baker_zh/TTS/local/fbank.py b/egs/baker_zh/TTS/local/fbank.py
new file mode 120000
index 000000000..5bcf1fde5
--- /dev/null
+++ b/egs/baker_zh/TTS/local/fbank.py
@@ -0,0 +1 @@
+../matcha/fbank.py
\ No newline at end of file
diff --git a/egs/baker_zh/TTS/local/generate_tokens.py b/egs/baker_zh/TTS/local/generate_tokens.py
old mode 100644
new mode 100755
index 9d51cbfc7..0f469aaf3
--- a/egs/baker_zh/TTS/local/generate_tokens.py
+++ b/egs/baker_zh/TTS/local/generate_tokens.py
@@ -46,9 +46,13 @@ def generate_token_list() -> List[str]:
     ans = list(token_set)
     ans.sort()
 
+    punctuations = list(",.!?:\"'")
+    ans = punctuations + ans
+
     # use ID 0 for blank
-    # We use blank for padding
+    # Use ID 1 ("_") for padding
     ans.insert(0, " ")
+    ans.insert(1, "_")
 
     #
     return ans
diff --git a/egs/baker_zh/TTS/local/validate_manifest.py b/egs/baker_zh/TTS/local/validate_manifest.py
new file mode 100755
index 000000000..4e31028f7
--- /dev/null
+++ b/egs/baker_zh/TTS/local/validate_manifest.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# Copyright    2022-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+
+We will add more checks later if needed.
+ +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/spectrogram/baker_zh_cuts_all.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset.speech_synthesis import validate_for_tts + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet), type(cut_set) + + validate_for_tts(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/baker_zh/TTS/matcha/tokenizer.py b/egs/baker_zh/TTS/matcha/tokenizer.py deleted file mode 120000 index dbe32da2e..000000000 --- a/egs/baker_zh/TTS/matcha/tokenizer.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/matcha/tokenizer.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/tokenizer.py b/egs/baker_zh/TTS/matcha/tokenizer.py new file mode 100644 index 000000000..d5c277ffe --- /dev/null +++ b/egs/baker_zh/TTS/matcha/tokenizer.py @@ -0,0 +1,119 @@ +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +import logging +from typing import Dict, List + +import tacotron_cleaner.cleaners + +try: + from piper_phonemize import phonemize_espeak +except Exception as ex: + raise RuntimeError( + f"{ex}\nPlease run\n" + "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html" + ) + +from utils import intersperse + + +# This tokenizer supports both English and Chinese. +# We assume you have used +# ../local/convert_text_to_tokens.py +# to process your text +class Tokenizer(object): + def __init__(self, tokens: str): + """ + Args: + tokens: the file that maps tokens to ids + """ + # Parse token file + self.token2id: Dict[str, int] = {} + with open(tokens, "r", encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split() + if len(info) == 1: + # case of space + token = " " + id = int(info[0]) + else: + token, id = info[0], int(info[1]) + assert token not in self.token2id, token + self.token2id[token] = id + + # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md + self.pad_id = self.token2id["_"] # padding + self.space_id = self.token2id[" "] # word separator (whitespace) + + self.vocab_size = len(self.token2id) + + def texts_to_token_ids( + self, + sentence_list: List[List[str]], + intersperse_blank: bool = True, + lang: str = "en-us", + ) -> List[List[int]]: + """ + Args: + sentence_list: + A list of sentences. + intersperse_blank: + Whether to intersperse blanks in the token sequence. + lang: + Language argument passed to phonemize_espeak(). 
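+            Note that phonemize_espeak() is only called for words that are
+            not already in the token table, e.g., English words in otherwise
+            pinyin input.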
+ + Returns: + Return a list of token id list [utterance][token_id] + """ + token_ids_list = [] + + for sentence in sentence_list: + tokens_list = [] + for word in sentence: + if word in self.token2id: + tokens_list.append(word) + continue + + tmp_tokens_list = phonemize_espeak(word, lang) + for t in tmp_tokens_list: + tokens_list.extend(t) + + token_ids = [] + for t in tokens_list: + if t not in self.token2id: + logging.warning(f"Skip OOV {t}") + continue + + if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id: + continue + + token_ids.append(self.token2id[t]) + + if intersperse_blank: + token_ids = intersperse(token_ids, self.pad_id) + + token_ids_list.append(token_ids) + + return token_ids_list + + +def test_tokenizer(): + import jieba + from pypinyin import lazy_pinyin, Style + + tokenizer = Tokenizer("data/tokens.txt") + text1 = "今天is Monday, tomorrow is 星期二" + text2 = "你好吗? 我很好, how about you?" + + text1 = list(jieba.cut(text1)) + text2 = list(jieba.cut(text2)) + tokens1 = lazy_pinyin(text1, style=Style.TONE3, tone_sandhi=True) + tokens2 = lazy_pinyin(text2, style=Style.TONE3, tone_sandhi=True) + print(tokens1) + print(tokens2) + + ids = tokenizer.texts_to_token_ids([tokens1, tokens2]) + print(ids) + + +if __name__ == "__main__": + test_tokenizer() diff --git a/egs/baker_zh/TTS/matcha/train.py b/egs/baker_zh/TTS/matcha/train.py new file mode 100755 index 000000000..814cd1483 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/train.py @@ -0,0 +1,717 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + + +import argparse +import json +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Union + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.utils import fix_random_seed +from model import fix_len_compatibility +from models.matcha_tts import MatchaTTS +from tokenizer import Tokenizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import BakerZhTtsDataModule +from utils import MetricsTracker + +from icefall.checkpoint import load_checkpoint, save_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12335, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp", + help="""The experiment dir. 
It specifies the directory where all training related
+        files, e.g., checkpoints, logs, etc., are saved.
+        """,
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="data/tokens.txt",
+        help="""Path to the file that maps tokens to IDs.""",
+    )
+
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to the cmvn.json file containing the fbank mean and std.""",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=10,
+        help="""Save a checkpoint periodically after processing this number of
+        epochs. We save a checkpoint to exp-dir/ whenever
+        params.cur_epoch % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
+        Since training takes around 1000 epochs, we suggest using a large
+        save_every_n to save disk space.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    return parser
+
+
+def get_data_statistics():
+    return AttributeDict(
+        {
+            "mel_mean": 0,
+            "mel_std": 1,
+        }
+    )
+
+
+def _get_data_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "name": "baker-zh",
+            # Note: the two filelists below are not used by this recipe;
+            # cuts are loaded via tts_datamodule.py instead.
+            "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
+            "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
+            #  "batch_size": 64,
+            #  "num_workers": 1,
+            #  "pin_memory": False,
+            "cleaners": ["english_cleaners2"],
+            "add_blank": True,
+            "n_spks": 1,
+            "n_fft": 1024,
+            "n_feats": 80,
+            "sampling_rate": 22050,
+            "hop_length": 256,
+            "win_length": 1024,
+            "f_min": 0,
+            "f_max": 8000,
+            "seed": 1234,
+            "load_durations": False,
+            "data_statistics": get_data_statistics(),
+        }
+    )
+    return params
+
+
+def _get_model_params() -> AttributeDict:
+    n_feats = 80
+    filter_channels_dp = 256
+    encoder_params_p_dropout = 0.1
+    params = AttributeDict(
+        {
+            "n_spks": 1,  # for baker-zh.
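+            # (train.py currently passes "spks": None; set n_spks > 1 and
+            # pass real speaker IDs when adapting this recipe to a
+            # multi-speaker dataset.)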
+ "spk_emb_dim": 64, + "n_feats": n_feats, + "out_size": None, # or use 172 + "prior_loss": True, + "use_precomputed_durations": False, + "data_statistics": get_data_statistics(), + "encoder": AttributeDict( + { + "encoder_type": "RoPE Encoder", # not used + "encoder_params": AttributeDict( + { + "n_feats": n_feats, + "n_channels": 192, + "filter_channels": 768, + "filter_channels_dp": filter_channels_dp, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + "spk_emb_dim": 64, + "n_spks": 1, + "prenet": True, + } + ), + "duration_predictor_params": AttributeDict( + { + "filter_channels_dp": filter_channels_dp, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + } + ), + } + ), + "decoder": AttributeDict( + { + "channels": [256, 256], + "dropout": 0.05, + "attention_head_dim": 64, + "n_blocks": 1, + "num_mid_blocks": 2, + "num_heads": 2, + "act_fn": "snakebeta", + } + ), + "cfm": AttributeDict( + { + "name": "CFM", + "solver": "euler", + "sigma_min": 1e-4, + } + ), + "optimizer": AttributeDict( + { + "lr": 1e-4, + "weight_decay": 0.0, + } + ), + } + ) + + return params + + +def get_params(): + params = AttributeDict( + { + "model_args": _get_model_params(), + "data_args": _get_data_params(), + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 10, + "valid_interval": 1500, + "env_info": get_env_info(), + } + ) + return params + + +def get_model(params): + m = MatchaTTS(**params.model_args) + return m + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+
+    saved_params = load_checkpoint(filename, model=model)
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    return saved_params
+
+
+def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
+    """Parse batch data"""
+    mel_mean = params.data_args.data_statistics.mel_mean
+    mel_std_inv = 1 / params.data_args.data_statistics.mel_std
+    for i in range(batch["features"].shape[0]):
+        n = batch["features_lens"][i]
+        batch["features"][i : i + 1, :n, :] = (
+            batch["features"][i : i + 1, :n, :] - mel_mean
+        ) * mel_std_inv
+        batch["features"][i : i + 1, n:, :] = 0
+
+    audio = batch["audio"].to(device)
+    features = batch["features"].to(device)
+    audio_lens = batch["audio_lens"].to(device)
+    features_lens = batch["features_lens"].to(device)
+    tokens = batch["tokens"]
+
+    tokens = tokenizer.texts_to_token_ids(tokens, intersperse_blank=True)
+    tokens = k2.RaggedTensor(tokens)
+    row_splits = tokens.shape.row_splits(1)
+    tokens_lens = row_splits[1:] - row_splits[:-1]
+    tokens = tokens.to(device)
+    tokens_lens = tokens_lens.to(device)
+    # a tensor of shape (B, T)
+    tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
+
+    max_feature_length = fix_len_compatibility(features.shape[1])
+    if max_feature_length > features.shape[1]:
+        pad = max_feature_length - features.shape[1]
+        features = torch.nn.functional.pad(features, (0, 0, 0, pad))
+
+        # features_lens[features_lens.argmax()] += pad
+
+    return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long()
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    tokenizer: Tokenizer,
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+    rank: int = 0,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
+
+    # used to summarize the stats over iterations
+    tot_loss = MetricsTracker()
+
+    with torch.no_grad():
+        for batch_idx, batch in enumerate(valid_dl):
+            (
+                audio,
+                audio_lens,
+                features,
+                features_lens,
+                tokens,
+                tokens_lens,
+            ) = prepare_input(batch, tokenizer, device, params)
+
+            losses = get_losses(
+                {
+                    "x": tokens,
+                    "x_lengths": tokens_lens,
+                    "y": features.permute(0, 2, 1),
+                    "y_lengths": features_lens,
+                    "spks": None,  # should change it for multi-speakers
+                    "durations": None,
+                }
+            )
+
+            batch_size = len(batch["tokens"])
+
+            loss_info = MetricsTracker()
+            loss_info["samples"] = batch_size
+
+            s = 0
+
+            for key, value in losses.items():
+                v = value.detach().item()
+                loss_info[key] = v * batch_size
+                s += v * batch_size
+
+            loss_info["tot_loss"] = s
+
+            # summary stats
+            tot_loss = tot_loss + loss_info
+
+        if world_size > 1:
+            tot_loss.reduce(device)
+
+        loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
+        if loss_value < params.best_valid_loss:
+            params.best_valid_epoch = params.cur_epoch
+            params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    tokenizer: Tokenizer,
+    optimizer: Optimizer,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      tokenizer:
+        Used to convert text tokens to their IDs.
+      optimizer:
+        The optimizer.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+    """
+    model.train()
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
+
+    # used to track the stats over iterations in one epoch
+    tot_loss = MetricsTracker()
+
+    saved_bad_model = False
+
+    def save_bad_model(suffix: str = ""):
+        save_checkpoint(
+            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+            model=model,
+            params=params,
+            optimizer=optimizer,
+            scaler=scaler,
+            rank=0,
+        )
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        # audio: (N, T), float32
+        # features: (N, T, C), float32
+        # audio_lens, (N,), int32
+        # features_lens, (N,), int32
+        # tokens: List[List[str]], len(tokens) == N
+
+        batch_size = len(batch["tokens"])
+
+        (
+            audio,
+            audio_lens,
+            features,
+            features_lens,
+            tokens,
+            tokens_lens,
+        ) = prepare_input(batch, tokenizer, device, params)
+        try:
+            with autocast(enabled=params.use_fp16):
+                losses = get_losses(
+                    {
+                        "x": tokens,
+                        "x_lengths": tokens_lens,
+                        "y": features.permute(0, 2, 1),
+                        "y_lengths": features_lens,
+                        "spks": None,  # should change it for multi-speakers
+                        "durations": None,
+                    }
+                )
+
+                loss = sum(losses.values())
+
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+
+            loss_info = MetricsTracker()
+            loss_info["samples"] = batch_size
+
+            s = 0
+
+            for key, value in losses.items():
+                v = value.detach().item()
+                loss_info[key] = v * batch_size
+                s += v * batch_size
+
+            loss_info["tot_loss"] = s
+
+            tot_loss = tot_loss + loss_info
+        except:  # noqa
+            save_bad_model()
+            raise
+
+        if params.batch_idx_train % 100 == 0 and params.use_fp16:
+            # If the grad scale was less than 1, try increasing it.
+            # The _growth_interval of the grad scaler is configurable,
+            # but we can't configure it to have different
+            # behavior depending on the current grad scale.
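+            # (Note: scaler._scale is an internal attribute of GradScaler;
+            # scaler.get_scale() is the public accessor for the same value.)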
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, " + f"batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + rank=rank, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + "Maximum memory allocated so far is " + f"{torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["tot_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.pad_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + + logging.info(params) + print(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of parameters: {num_param}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = 
torch.optim.Adam(model.parameters(), **params.model_args.optimizer) + + logging.info("About to create datamodule") + + baker_zh = BakerZhTtsDataModule(args) + + train_cuts = baker_zh.train_cuts() + train_dl = baker_zh.train_dataloaders(train_cuts) + + valid_cuts = baker_zh.valid_cuts() + valid_dl = baker_zh.valid_dataloaders(valid_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + fix_random_seed(params.seed + epoch - 1) + if "sampler" in train_dl: + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer=optimizer, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + BakerZhTtsDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/baker_zh/TTS/matcha/tts_datamodule.py b/egs/baker_zh/TTS/matcha/tts_datamodule.py new file mode 100644 index 000000000..d2bdfb96c --- /dev/null +++ b/egs/baker_zh/TTS/matcha/tts_datamodule.py @@ -0,0 +1,340 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from fbank import MatchaFbank, MatchaFbankConfig
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    PrecomputedFeatures,
+    SimpleCutSampler,
+    SpeechSynthesisDataset,
+)
+from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
+    AudioSamples,
+    OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class BakerZhTtsDataModule:
+    """
+    DataModule for TTS experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+    and test-other).
+
+    It contains all the common data pipeline modules used in TTS
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in TTS tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="TTS data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
+        )
+
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/fbank"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--max-duration",
+            type=int,
+            default=200,
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
+        )
+        group.add_argument(
+            "--bucketing-sampler",
+            type=str2bool,
+            default=True,
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=30,
+            help="The number of buckets for the DynamicBucketingSampler "
+            "(you might want to increase it for larger datasets).",
+        )
+
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
+        )
+        group.add_argument(
+            "--drop-last",
+            type=str2bool,
+            default=True,
+            help="Whether to drop last batch. 
Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
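+        # Each dataloader worker then derives its own seed as
+        # seed + worker_id via _SeedWorkers, so that workers do not share
+        # identical RNG states.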
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + pin_memory=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=True, + pin_memory=True, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz" + ) diff --git a/egs/baker_zh/TTS/prepare.sh b/egs/baker_zh/TTS/prepare.sh index e5fcf0278..e15e3d850 100755 --- a/egs/baker_zh/TTS/prepare.sh +++ b/egs/baker_zh/TTS/prepare.sh @@ -82,3 +82,70 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then python3 ./local/generate_tokens.py --tokens data/tokens.txt fi fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Generate raw cutset" + if [ ! 
-e data/manifests/baker_zh_cuts_raw.jsonl.gz ]; then
+    lhotse cut simple \
+      -r ./data/manifests/baker_zh_recordings_all.jsonl.gz \
+      -s ./data/manifests/baker_zh_supervisions_all.jsonl.gz \
+      ./data/manifests/baker_zh_cuts_raw.jsonl.gz
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Convert text to tokens"
+  if [ ! -e data/manifests/baker_zh_cuts.jsonl.gz ]; then
+    python3 ./local/convert_text_to_tokens.py \
+      --in-file ./data/manifests/baker_zh_cuts_raw.jsonl.gz \
+      --out-file ./data/manifests/baker_zh_cuts.jsonl.gz
+  fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Generate fbank (used by ./matcha)"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.baker-zh.done ]; then
+    ./local/compute_fbank_baker_zh.py
+    touch data/fbank/.baker-zh.done
+  fi
+
+  if [ ! -e data/fbank/.baker-zh-validated.done ]; then
+    log "Validating data/fbank for baker-zh (used by ./matcha)"
+    python3 ./local/validate_manifest.py \
+      data/fbank/baker_zh_cuts.jsonl.gz
+    touch data/fbank/.baker-zh-validated.done
+  fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Split the baker-zh cuts into train, valid and test sets (used by ./matcha)"
+  if [ ! -e data/fbank/.baker_zh_split.done ]; then
+    lhotse subset --last 600 \
+      data/fbank/baker_zh_cuts.jsonl.gz \
+      data/fbank/baker_zh_cuts_validtest.jsonl.gz
+    lhotse subset --first 100 \
+      data/fbank/baker_zh_cuts_validtest.jsonl.gz \
+      data/fbank/baker_zh_cuts_valid.jsonl.gz
+    lhotse subset --last 500 \
+      data/fbank/baker_zh_cuts_validtest.jsonl.gz \
+      data/fbank/baker_zh_cuts_test.jsonl.gz
+
+    rm data/fbank/baker_zh_cuts_validtest.jsonl.gz
+
+    n=$(( $(gunzip -c data/fbank/baker_zh_cuts.jsonl.gz | wc -l) - 600 ))
+
+    lhotse subset --first $n \
+      data/fbank/baker_zh_cuts.jsonl.gz \
+      data/fbank/baker_zh_cuts_train.jsonl.gz
+
+    touch data/fbank/.baker_zh_split.done
+  fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Compute fbank mean and std (used by ./matcha)"
+  if [ ! -f ./data/fbank/cmvn.json ]; then
+    ./local/compute_fbank_statistics.py ./data/fbank/baker_zh_cuts_train.jsonl.gz ./data/fbank/cmvn.json
+  fi
+fi
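+
+# Once all stages above have finished, training can be launched with
+# ./matcha/train.py. A minimal single-GPU sketch (all flags below are
+# defined in ./matcha/train.py; the values shown are only illustrative):
+#
+#   ./matcha/train.py \
+#     --world-size 1 \
+#     --num-epochs 1000 \
+#     --exp-dir ./matcha/exp \
+#     --tokens ./data/tokens.txt \
+#     --cmvn ./data/fbank/cmvn.json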