mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
ready to train
This commit is contained in:
parent
f4d6fb06aa
commit
e4d40baaf5
1
egs/baker_zh/TTS/local/audio.py
Symbolic link
1
egs/baker_zh/TTS/local/audio.py
Symbolic link
@ -0,0 +1 @@
|
||||
../matcha/audio.py
|
110
egs/baker_zh/TTS/local/compute_fbank_baker_zh.py
Executable file
110
egs/baker_zh/TTS/local/compute_fbank_baker_zh.py
Executable file
@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the baker-zh dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from fbank import MatchaFbank, MatchaFbankConfig
|
||||
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
|
||||
from lhotse.audio import RecordingSet
|
||||
from lhotse.supervision import SupervisionSet
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-jobs",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
""",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def compute_fbank_baker_zh(num_jobs: int):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
if num_jobs < 1:
|
||||
num_jobs = os.cpu_count()
|
||||
|
||||
logging.info(f"num_jobs: {num_jobs}")
|
||||
logging.info(f"src_dir: {src_dir}")
|
||||
logging.info(f"output_dir: {output_dir}")
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=22050,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
|
||||
prefix = "baker_zh"
|
||||
suffix = "jsonl.gz"
|
||||
|
||||
extractor = MatchaFbank(config)
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
cuts_filename = f"{prefix}_cuts.{suffix}"
|
||||
logging.info(f"Processing {cuts_filename}")
|
||||
cut_set = load_manifest(src_dir / cuts_filename).resample(22050)
|
||||
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats",
|
||||
num_jobs=num_jobs if ex is None else 80,
|
||||
executor=ex,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
|
||||
cut_set.to_file(output_dir / cuts_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
args = get_parser().parse_args()
|
||||
compute_fbank_baker_zh(args.num_jobs)
|
84
egs/baker_zh/TTS/local/compute_fbank_statistics.py
Executable file
84
egs/baker_zh/TTS/local/compute_fbank_statistics.py
Executable file
@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script compute the mean and std of the fbank features.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"manifest",
|
||||
type=Path,
|
||||
help="Path to the manifest file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"cmvn",
|
||||
type=Path,
|
||||
help="Path to the cmvn.json",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
manifest = args.manifest
|
||||
logging.info(
|
||||
f"Computing fbank mean and std for {manifest} and saving to {args.cmvn}"
|
||||
)
|
||||
|
||||
assert manifest.is_file(), f"{manifest} does not exist"
|
||||
cut_set = load_manifest_lazy(manifest)
|
||||
assert isinstance(cut_set, CutSet), type(cut_set)
|
||||
|
||||
feat_dim = cut_set[0].features.num_features
|
||||
num_frames = 0
|
||||
s = 0
|
||||
sq = 0
|
||||
for c in cut_set:
|
||||
f = torch.from_numpy(c.load_features())
|
||||
num_frames += f.shape[0]
|
||||
s += f.sum()
|
||||
sq += f.square().sum()
|
||||
|
||||
fbank_mean = s / (num_frames * feat_dim)
|
||||
fbank_var = sq / (num_frames * feat_dim) - fbank_mean * fbank_mean
|
||||
print("fbank var", fbank_var)
|
||||
fbank_std = fbank_var.sqrt()
|
||||
with open(args.cmvn, "w") as f:
|
||||
json.dump({"fbank_mean": fbank_mean.item(), "fbank_std": fbank_std.item()}, f)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
119
egs/baker_zh/TTS/local/convert_text_to_tokens.py
Executable file
119
egs/baker_zh/TTS/local/convert_text_to_tokens.py
Executable file
@ -0,0 +1,119 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import jieba
|
||||
from lhotse import load_manifest
|
||||
from pypinyin import lazy_pinyin, load_phrases_dict, Style
|
||||
|
||||
load_phrases_dict(
|
||||
{
|
||||
"行长": [["hang2"], ["zhang3"]],
|
||||
"银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]],
|
||||
}
|
||||
)
|
||||
|
||||
whiter_space_re = re.compile(r"\s+")
|
||||
|
||||
punctuations_re = [
|
||||
(re.compile(x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
(",", ","),
|
||||
("。", "."),
|
||||
("!", "!"),
|
||||
("?", "?"),
|
||||
("“", '"'),
|
||||
("”", '"'),
|
||||
("‘", "'"),
|
||||
("’", "'"),
|
||||
(":", ":"),
|
||||
("、", ","),
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument(
|
||||
"--in-file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Input cutset.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--out-file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output cutset.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def normalize_white_spaces(text):
|
||||
return whiter_space_re.sub(" ", text)
|
||||
|
||||
|
||||
def normalize_punctuations(text):
|
||||
for regex, replacement in punctuations_re:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def split_text(text: str) -> List[str]:
|
||||
"""
|
||||
Example input: '你好呀,You are 一个好人。 去银行存钱?How about you?'
|
||||
Example output: ['你好', '呀', ',', 'you are', '一个', '好人', '.', '去', '银行', '存钱', '?', 'how about you', '?']
|
||||
"""
|
||||
text = text.lower()
|
||||
text = normalize_white_spaces(text)
|
||||
text = normalize_punctuations(text)
|
||||
ans = []
|
||||
|
||||
for seg in jieba.cut(text):
|
||||
if seg in ",.!?:\"'":
|
||||
ans.append(seg)
|
||||
elif seg == " " and len(ans) > 0:
|
||||
if ord("a") <= ord(ans[-1][-1]) <= ord("z"):
|
||||
ans[-1] += seg
|
||||
elif ord("a") <= ord(seg[0]) <= ord("z"):
|
||||
if len(ans) == 0:
|
||||
ans.append(seg)
|
||||
continue
|
||||
|
||||
if ans[-1][-1] == " ":
|
||||
ans[-1] += seg
|
||||
continue
|
||||
|
||||
ans.append(seg)
|
||||
else:
|
||||
ans.append(seg)
|
||||
|
||||
ans = [s.strip() for s in ans]
|
||||
return ans
|
||||
|
||||
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
cuts = load_manifest(args.in_file)
|
||||
for c in cuts:
|
||||
assert len(c.supervisions) == 1, (len(c.supervisions), c.supervisions)
|
||||
text = c.supervisions[0].normalized_text
|
||||
|
||||
text_list = split_text(text)
|
||||
tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True)
|
||||
|
||||
c.supervisions[0].tokens = tokens
|
||||
|
||||
cuts.to_file(args.out_file)
|
||||
|
||||
print(f"saved to {args.out_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/baker_zh/TTS/local/fbank.py
Symbolic link
1
egs/baker_zh/TTS/local/fbank.py
Symbolic link
@ -0,0 +1 @@
|
||||
../matcha/fbank.py
|
6
egs/baker_zh/TTS/local/generate_tokens.py
Normal file → Executable file
6
egs/baker_zh/TTS/local/generate_tokens.py
Normal file → Executable file
@ -46,9 +46,13 @@ def generate_token_list() -> List[str]:
|
||||
ans = list(token_set)
|
||||
ans.sort()
|
||||
|
||||
punctuations = list(",.!?:\"'")
|
||||
ans = punctuations + ans
|
||||
|
||||
# use ID 0 for blank
|
||||
# We use blank for padding
|
||||
# Use ID 1 of _ for padding
|
||||
ans.insert(0, " ")
|
||||
ans.insert(1, "_") #
|
||||
|
||||
return ans
|
||||
|
||||
|
70
egs/baker_zh/TTS/local/validate_manifest.py
Executable file
70
egs/baker_zh/TTS/local/validate_manifest.py
Executable file
@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script checks the following assumptions of the generated manifest:
|
||||
|
||||
- Single supervision per cut
|
||||
|
||||
We will add more checks later if needed.
|
||||
|
||||
Usage example:
|
||||
|
||||
python3 ./local/validate_manifest.py \
|
||||
./data/spectrogram/baker_zh_cuts_all.jsonl.gz
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.dataset.speech_synthesis import validate_for_tts
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"manifest",
|
||||
type=Path,
|
||||
help="Path to the manifest file",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
manifest = args.manifest
|
||||
logging.info(f"Validating {manifest}")
|
||||
|
||||
assert manifest.is_file(), f"{manifest} does not exist"
|
||||
cut_set = load_manifest_lazy(manifest)
|
||||
assert isinstance(cut_set, CutSet), type(cut_set)
|
||||
|
||||
validate_for_tts(cut_set)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
main()
|
@ -1 +0,0 @@
|
||||
../../../ljspeech/TTS/matcha/tokenizer.py
|
119
egs/baker_zh/TTS/matcha/tokenizer.py
Normal file
119
egs/baker_zh/TTS/matcha/tokenizer.py
Normal file
@ -0,0 +1,119 @@
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
import tacotron_cleaner.cleaners
|
||||
|
||||
try:
|
||||
from piper_phonemize import phonemize_espeak
|
||||
except Exception as ex:
|
||||
raise RuntimeError(
|
||||
f"{ex}\nPlease run\n"
|
||||
"pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
|
||||
)
|
||||
|
||||
from utils import intersperse
|
||||
|
||||
|
||||
# This tokenizer supports both English and Chinese.
|
||||
# We assume you have used
|
||||
# ../local/convert_text_to_tokens.py
|
||||
# to process your text
|
||||
class Tokenizer(object):
|
||||
def __init__(self, tokens: str):
|
||||
"""
|
||||
Args:
|
||||
tokens: the file that maps tokens to ids
|
||||
"""
|
||||
# Parse token file
|
||||
self.token2id: Dict[str, int] = {}
|
||||
with open(tokens, "r", encoding="utf-8") as f:
|
||||
for line in f.readlines():
|
||||
info = line.rstrip().split()
|
||||
if len(info) == 1:
|
||||
# case of space
|
||||
token = " "
|
||||
id = int(info[0])
|
||||
else:
|
||||
token, id = info[0], int(info[1])
|
||||
assert token not in self.token2id, token
|
||||
self.token2id[token] = id
|
||||
|
||||
# Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
|
||||
self.pad_id = self.token2id["_"] # padding
|
||||
self.space_id = self.token2id[" "] # word separator (whitespace)
|
||||
|
||||
self.vocab_size = len(self.token2id)
|
||||
|
||||
def texts_to_token_ids(
|
||||
self,
|
||||
sentence_list: List[List[str]],
|
||||
intersperse_blank: bool = True,
|
||||
lang: str = "en-us",
|
||||
) -> List[List[int]]:
|
||||
"""
|
||||
Args:
|
||||
sentence_list:
|
||||
A list of sentences.
|
||||
intersperse_blank:
|
||||
Whether to intersperse blanks in the token sequence.
|
||||
lang:
|
||||
Language argument passed to phonemize_espeak().
|
||||
|
||||
Returns:
|
||||
Return a list of token id list [utterance][token_id]
|
||||
"""
|
||||
token_ids_list = []
|
||||
|
||||
for sentence in sentence_list:
|
||||
tokens_list = []
|
||||
for word in sentence:
|
||||
if word in self.token2id:
|
||||
tokens_list.append(word)
|
||||
continue
|
||||
|
||||
tmp_tokens_list = phonemize_espeak(word, lang)
|
||||
for t in tmp_tokens_list:
|
||||
tokens_list.extend(t)
|
||||
|
||||
token_ids = []
|
||||
for t in tokens_list:
|
||||
if t not in self.token2id:
|
||||
logging.warning(f"Skip OOV {t}")
|
||||
continue
|
||||
|
||||
if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id:
|
||||
continue
|
||||
|
||||
token_ids.append(self.token2id[t])
|
||||
|
||||
if intersperse_blank:
|
||||
token_ids = intersperse(token_ids, self.pad_id)
|
||||
|
||||
token_ids_list.append(token_ids)
|
||||
|
||||
return token_ids_list
|
||||
|
||||
|
||||
def test_tokenizer():
|
||||
import jieba
|
||||
from pypinyin import lazy_pinyin, Style
|
||||
|
||||
tokenizer = Tokenizer("data/tokens.txt")
|
||||
text1 = "今天is Monday, tomorrow is 星期二"
|
||||
text2 = "你好吗? 我很好, how about you?"
|
||||
|
||||
text1 = list(jieba.cut(text1))
|
||||
text2 = list(jieba.cut(text2))
|
||||
tokens1 = lazy_pinyin(text1, style=Style.TONE3, tone_sandhi=True)
|
||||
tokens2 = lazy_pinyin(text2, style=Style.TONE3, tone_sandhi=True)
|
||||
print(tokens1)
|
||||
print(tokens2)
|
||||
|
||||
ids = tokenizer.texts_to_token_ids([tokens1, tokens2])
|
||||
print(ids)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_tokenizer()
|
717
egs/baker_zh/TTS/matcha/train.py
Executable file
717
egs/baker_zh/TTS/matcha/train.py
Executable file
@ -0,0 +1,717 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import fix_len_compatibility
|
||||
from models.matcha_tts import MatchaTTS
|
||||
from tokenizer import Tokenizer
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tts_datamodule import BakerZhTtsDataModule
|
||||
from utils import MetricsTracker
|
||||
|
||||
from icefall.checkpoint import load_checkpoint, save_checkpoint
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--master-port",
|
||||
type=int,
|
||||
default=12335,
|
||||
help="Master port to use for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tensorboard",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Should various information be logged in tensorboard.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Resume training from this epoch. It should be positive.
|
||||
If larger than 1, it will load checkpoint from
|
||||
exp-dir/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=Path,
|
||||
default="matcha/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/tokens.txt",
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cmvn",
|
||||
type=str,
|
||||
default="data/fbank/cmvn.json",
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="The seed for random generators intended for reproducibility",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save-every-n",
|
||||
type=int,
|
||||
default=10,
|
||||
help="""Save checkpoint after processing this number of epochs"
|
||||
periodically. We save checkpoint to exp-dir/ whenever
|
||||
params.cur_epoch % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
|
||||
Since it will take around 1000 epochs, we suggest using a large
|
||||
save_every_n to save disk space.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-fp16",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_data_statistics():
|
||||
return AttributeDict(
|
||||
{
|
||||
"mel_mean": 0,
|
||||
"mel_std": 1,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_data_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"name": "baker-zh",
|
||||
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
|
||||
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
|
||||
# "batch_size": 64,
|
||||
# "num_workers": 1,
|
||||
# "pin_memory": False,
|
||||
"cleaners": ["english_cleaners2"],
|
||||
"add_blank": True,
|
||||
"n_spks": 1,
|
||||
"n_fft": 1024,
|
||||
"n_feats": 80,
|
||||
"sampling_rate": 22050,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"f_min": 0,
|
||||
"f_max": 8000,
|
||||
"seed": 1234,
|
||||
"load_durations": False,
|
||||
"data_statistics": get_data_statistics(),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def _get_model_params() -> AttributeDict:
|
||||
n_feats = 80
|
||||
filter_channels_dp = 256
|
||||
encoder_params_p_dropout = 0.1
|
||||
params = AttributeDict(
|
||||
{
|
||||
"n_spks": 1, # for baker-zh.
|
||||
"spk_emb_dim": 64,
|
||||
"n_feats": n_feats,
|
||||
"out_size": None, # or use 172
|
||||
"prior_loss": True,
|
||||
"use_precomputed_durations": False,
|
||||
"data_statistics": get_data_statistics(),
|
||||
"encoder": AttributeDict(
|
||||
{
|
||||
"encoder_type": "RoPE Encoder", # not used
|
||||
"encoder_params": AttributeDict(
|
||||
{
|
||||
"n_feats": n_feats,
|
||||
"n_channels": 192,
|
||||
"filter_channels": 768,
|
||||
"filter_channels_dp": filter_channels_dp,
|
||||
"n_heads": 2,
|
||||
"n_layers": 6,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": encoder_params_p_dropout,
|
||||
"spk_emb_dim": 64,
|
||||
"n_spks": 1,
|
||||
"prenet": True,
|
||||
}
|
||||
),
|
||||
"duration_predictor_params": AttributeDict(
|
||||
{
|
||||
"filter_channels_dp": filter_channels_dp,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": encoder_params_p_dropout,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
"decoder": AttributeDict(
|
||||
{
|
||||
"channels": [256, 256],
|
||||
"dropout": 0.05,
|
||||
"attention_head_dim": 64,
|
||||
"n_blocks": 1,
|
||||
"num_mid_blocks": 2,
|
||||
"num_heads": 2,
|
||||
"act_fn": "snakebeta",
|
||||
}
|
||||
),
|
||||
"cfm": AttributeDict(
|
||||
{
|
||||
"name": "CFM",
|
||||
"solver": "euler",
|
||||
"sigma_min": 1e-4,
|
||||
}
|
||||
),
|
||||
"optimizer": AttributeDict(
|
||||
{
|
||||
"lr": 1e-4,
|
||||
"weight_decay": 0.0,
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def get_params():
|
||||
params = AttributeDict(
|
||||
{
|
||||
"model_args": _get_model_params(),
|
||||
"data_args": _get_data_params(),
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": -1, # 0
|
||||
"log_interval": 10,
|
||||
"valid_interval": 1500,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def get_model(params):
|
||||
m = MatchaTTS(**params.model_args)
|
||||
return m
|
||||
|
||||
|
||||
def load_checkpoint_if_available(
|
||||
params: AttributeDict, model: nn.Module
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Load checkpoint from file.
|
||||
|
||||
If params.start_epoch is larger than 1, it will load the checkpoint from
|
||||
`params.start_epoch - 1`.
|
||||
|
||||
Apart from loading state dict for `model` and `optimizer` it also updates
|
||||
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
and `best_valid_loss` in `params`.
|
||||
|
||||
Args:
|
||||
params:
|
||||
The return value of :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
Returns:
|
||||
Return a dict containing previously saved training info.
|
||||
"""
|
||||
if params.start_epoch > 1:
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
else:
|
||||
return None
|
||||
|
||||
assert filename.is_file(), f"{filename} does not exist!"
|
||||
|
||||
saved_params = load_checkpoint(filename, model=model)
|
||||
|
||||
keys = [
|
||||
"best_train_epoch",
|
||||
"best_valid_epoch",
|
||||
"batch_idx_train",
|
||||
"best_train_loss",
|
||||
"best_valid_loss",
|
||||
]
|
||||
for k in keys:
|
||||
params[k] = saved_params[k]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
|
||||
"""Parse batch data"""
|
||||
mel_mean = params.data_args.data_statistics.mel_mean
|
||||
mel_std_inv = 1 / params.data_args.data_statistics.mel_std
|
||||
for i in range(batch["features"].shape[0]):
|
||||
n = batch["features_lens"][i]
|
||||
batch["features"][i : i + 1, :n, :] = (
|
||||
batch["features"][i : i + 1, :n, :] - mel_mean
|
||||
) * mel_std_inv
|
||||
batch["features"][i : i + 1, n:, :] = 0
|
||||
|
||||
audio = batch["audio"].to(device)
|
||||
features = batch["features"].to(device)
|
||||
audio_lens = batch["audio_lens"].to(device)
|
||||
features_lens = batch["features_lens"].to(device)
|
||||
tokens = batch["tokens"]
|
||||
|
||||
tokens = tokenizer.tokens_to_token_ids(tokens, intersperse_blank=True)
|
||||
tokens = k2.RaggedTensor(tokens)
|
||||
row_splits = tokens.shape.row_splits(1)
|
||||
tokens_lens = row_splits[1:] - row_splits[:-1]
|
||||
tokens = tokens.to(device)
|
||||
tokens_lens = tokens_lens.to(device)
|
||||
# a tensor of shape (B, T)
|
||||
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
|
||||
|
||||
max_feature_length = fix_len_compatibility(features.shape[1])
|
||||
if max_feature_length > features.shape[1]:
|
||||
pad = max_feature_length - features.shape[1]
|
||||
features = torch.nn.functional.pad(features, (0, 0, 0, pad))
|
||||
|
||||
# features_lens[features_lens.argmax()] += pad
|
||||
|
||||
return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long()
|
||||
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
tokenizer: Tokenizer,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
) -> MetricsTracker:
|
||||
"""Run the validation process."""
|
||||
model.eval()
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
|
||||
|
||||
# used to summary the stats over iterations
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
(
|
||||
audio,
|
||||
audio_lens,
|
||||
features,
|
||||
features_lens,
|
||||
tokens,
|
||||
tokens_lens,
|
||||
) = prepare_input(batch, tokenizer, device, params)
|
||||
|
||||
losses = get_losses(
|
||||
{
|
||||
"x": tokens,
|
||||
"x_lengths": tokens_lens,
|
||||
"y": features.permute(0, 2, 1),
|
||||
"y_lengths": features_lens,
|
||||
"spks": None, # should change it for multi-speakers
|
||||
"durations": None,
|
||||
}
|
||||
)
|
||||
|
||||
batch_size = len(batch["tokens"])
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
|
||||
s = 0
|
||||
|
||||
for key, value in losses.items():
|
||||
v = value.detach().item()
|
||||
loss_info[key] = v * batch_size
|
||||
s += v * batch_size
|
||||
|
||||
loss_info["tot_loss"] = s
|
||||
|
||||
# summary stats
|
||||
tot_loss = tot_loss + loss_info
|
||||
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(device)
|
||||
|
||||
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
|
||||
if loss_value < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = loss_value
|
||||
|
||||
return tot_loss
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
tokenizer: Tokenizer,
|
||||
optimizer: Optimizer,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
The training loss from the mean of all frames is saved in
|
||||
`params.train_loss`. It runs the validation process every
|
||||
`params.valid_interval` batches.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The model for training.
|
||||
optimizer:
|
||||
The optimizer.
|
||||
train_dl:
|
||||
Dataloader for the training dataset.
|
||||
valid_dl:
|
||||
Dataloader for the validation dataset.
|
||||
scaler:
|
||||
The scaler used for mix precision training.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
"""
|
||||
model.train()
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
|
||||
|
||||
# used to track the stats over iterations in one epoch
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
saved_bad_model = False
|
||||
|
||||
def save_bad_model(suffix: str = ""):
|
||||
save_checkpoint(
|
||||
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scaler=scaler,
|
||||
rank=0,
|
||||
)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
# audio: (N, T), float32
|
||||
# features: (N, T, C), float32
|
||||
# audio_lens, (N,), int32
|
||||
# features_lens, (N,), int32
|
||||
# tokens: List[List[str]], len(tokens) == N
|
||||
|
||||
batch_size = len(batch["tokens"])
|
||||
|
||||
(
|
||||
audio,
|
||||
audio_lens,
|
||||
features,
|
||||
features_lens,
|
||||
tokens,
|
||||
tokens_lens,
|
||||
) = prepare_input(batch, tokenizer, device, params)
|
||||
try:
|
||||
with autocast(enabled=params.use_fp16):
|
||||
losses = get_losses(
|
||||
{
|
||||
"x": tokens,
|
||||
"x_lengths": tokens_lens,
|
||||
"y": features.permute(0, 2, 1),
|
||||
"y_lengths": features_lens,
|
||||
"spks": None, # should change it for multi-speakers
|
||||
"durations": None,
|
||||
}
|
||||
)
|
||||
|
||||
loss = sum(losses.values())
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
|
||||
s = 0
|
||||
|
||||
for key, value in losses.items():
|
||||
v = value.detach().item()
|
||||
loss_info[key] = v * batch_size
|
||||
s += v * batch_size
|
||||
|
||||
loss_info["tot_loss"] = s
|
||||
|
||||
tot_loss = tot_loss + loss_info
|
||||
except: # noqa
|
||||
save_bad_model()
|
||||
raise
|
||||
|
||||
if params.batch_idx_train % 100 == 0 and params.use_fp16:
|
||||
# If the grad scale was less than 1, try increasing it.
|
||||
# The _growth_interval of the grad scaler is configurable,
|
||||
# but we can't configure it to have different
|
||||
# behavior depending on the current grad scale.
|
||||
cur_grad_scale = scaler._scale.item()
|
||||
|
||||
if cur_grad_scale < 8.0 or (
|
||||
cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
|
||||
):
|
||||
scaler.update(cur_grad_scale * 2.0)
|
||||
if cur_grad_scale < 0.01:
|
||||
if not saved_bad_model:
|
||||
save_bad_model(suffix="-first-warning")
|
||||
saved_bad_model = True
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
|
||||
if params.batch_idx_train % params.log_interval == 0:
|
||||
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
||||
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||
f"global_batch_idx: {params.batch_idx_train}, "
|
||||
f"batch size: {batch_size}, "
|
||||
f"loss[{loss_info}], tot_loss[{tot_loss}], "
|
||||
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
||||
)
|
||||
|
||||
if tb_writer is not None:
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
if params.use_fp16:
|
||||
tb_writer.add_scalar(
|
||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||
)
|
||||
|
||||
if params.batch_idx_train % params.valid_interval == 1:
|
||||
logging.info("Computing validation loss")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
model.train()
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
logging.info(
|
||||
"Maximum memory allocated so far is "
|
||||
f"{torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
params.best_train_loss = params.train_loss
|
||||
|
||||
|
||||
def run(rank, world_size, args):
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
fix_random_seed(params.seed)
|
||||
if world_size > 1:
|
||||
setup_dist(rank, world_size, params.master_port)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
tokenizer = Tokenizer(params.tokens)
|
||||
params.pad_id = tokenizer.pad_id
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
params.model_args.n_vocab = params.vocab_size
|
||||
|
||||
with open(params.cmvn) as f:
|
||||
stats = json.load(f)
|
||||
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.data_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.model_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
logging.info(params)
|
||||
print(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of parameters: {num_param}")
|
||||
|
||||
assert params.start_epoch > 0, params.start_epoch
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)
|
||||
|
||||
logging.info("About to create datamodule")
|
||||
|
||||
baker_zh = BakerZhTtsDataModule(args)
|
||||
|
||||
train_cuts = baker_zh.train_cuts()
|
||||
train_dl = baker_zh.train_dataloaders(train_cuts)
|
||||
|
||||
valid_cuts = baker_zh.valid_cuts()
|
||||
valid_dl = baker_zh.valid_dataloaders(valid_cuts)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
||||
logging.info(f"Start epoch {epoch}")
|
||||
fix_random_seed(params.seed + epoch - 1)
|
||||
if "sampler" in train_dl:
|
||||
train_dl.sampler.set_epoch(epoch - 1)
|
||||
|
||||
params.cur_epoch = epoch
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
optimizer=optimizer,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
scaler=scaler,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
|
||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||
save_checkpoint(
|
||||
filename=filename,
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
if rank == 0:
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
BakerZhTtsDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||
else:
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
main()
|
340
egs/baker_zh/TTS/matcha/tts_datamodule.py
Normal file
340
egs/baker_zh/TTS/matcha/tts_datamodule.py
Normal file
@ -0,0 +1,340 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from fbank import MatchaFbank, MatchaFbankConfig
|
||||
from lhotse import CutSet, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpeechSynthesisDataset,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class BakerZhTtsDataModule:
|
||||
"""
|
||||
DataModule for tts experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="TTS data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to create train dataset")
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
train = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=self.args.num_buckets * 2000,
|
||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
num_buckets=self.args.num_buckets,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create valid dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.info("About to create test dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
sampling_rate = 22050
|
||||
config = MatchaFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=sampling_rate,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
test = SpeechSynthesisDataset(
|
||||
return_text=False,
|
||||
return_tokens=True,
|
||||
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
test_sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
num_buckets=self.args.num_buckets,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=test_sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get validation cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz"
|
||||
)
|
@ -82,3 +82,70 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
python3 ./local/generate_tokens.py --tokens data/tokens.txt
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Generate raw cutset"
|
||||
if [ ! -e data/manifests/baker_zh_cuts_raw.jsonl.gz ]; then
|
||||
lhotse cut simple \
|
||||
-r ./data/manifests/baker_zh_recordings_all.jsonl.gz \
|
||||
-s ./data/manifests/baker_zh_supervisions_all.jsonl.gz \
|
||||
./data/manifests/baker_zh_cuts_raw.jsonl.gz
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Convert text to tokens"
|
||||
if [ ! -e data/manifests/baker_zh_cuts.jsonl.gz ]; then
|
||||
python3 ./local/convert_text_to_tokens.py \
|
||||
--in-file ./data/manifests/baker_zh_cuts_raw.jsonl.gz \
|
||||
--out-file ./data/manifests/baker_zh_cuts.jsonl.gz
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Generate fbank (used by ./matcha)"
|
||||
mkdir -p data/fbank
|
||||
if [ ! -e data/fbank/.baker-zh.done ]; then
|
||||
./local/compute_fbank_baker_zh.py
|
||||
touch data/fbank/.baker-zh.done
|
||||
fi
|
||||
|
||||
if [ ! -e data/fbank/.baker-zh-validated.done ]; then
|
||||
log "Validating data/fbank for baker-zh (used by ./matcha)"
|
||||
python3 ./local/validate_manifest.py \
|
||||
data/fbank/baker_zh_cuts.jsonl.gz
|
||||
touch data/fbank/.baker-zh-validated.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Split the baker-zh cuts into train, valid and test sets (used by ./matcha)"
|
||||
if [ ! -e data/fbank/.baker_zh_split.done ]; then
|
||||
lhotse subset --last 600 \
|
||||
data/fbank/baker_zh_cuts.jsonl.gz \
|
||||
data/fbank/baker_zh_cuts_validtest.jsonl.gz
|
||||
lhotse subset --first 100 \
|
||||
data/fbank/baker_zh_cuts_validtest.jsonl.gz \
|
||||
data/fbank/baker_zh_cuts_valid.jsonl.gz
|
||||
lhotse subset --last 500 \
|
||||
data/fbank/baker_zh_cuts_validtest.jsonl.gz \
|
||||
data/fbank/baker_zh_cuts_test.jsonl.gz
|
||||
|
||||
rm data/fbank/baker_zh_cuts_validtest.jsonl.gz
|
||||
|
||||
n=$(( $(gunzip -c data/fbank/baker_zh_cuts.jsonl.gz | wc -l) - 600 ))
|
||||
|
||||
lhotse subset --first $n \
|
||||
data/fbank/baker_zh_cuts.jsonl.gz \
|
||||
data/fbank/baker_zh_cuts_train.jsonl.gz
|
||||
|
||||
touch data/fbank/.baker_zh_split.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "Stage 6: Compute fbank mean and std (used by ./matcha)"
|
||||
if [ ! -f ./data/fbank/cmvn.json ]; then
|
||||
./local/compute_fbank_statistics.py ./data/fbank/baker_zh_cuts_train.jsonl.gz ./data/fbank/cmvn.json
|
||||
fi
|
||||
fi
|
||||
|
Loading…
x
Reference in New Issue
Block a user