diff --git a/egs/baker_zh/TTS/README.md b/egs/baker_zh/TTS/README.md
deleted file mode 100644
index e69de29bb..000000000
diff --git a/egs/baker_zh/TTS/local/README.md b/egs/baker_zh/TTS/local/README.md
deleted file mode 100644
index dac138853..000000000
--- a/egs/baker_zh/TTS/local/README.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Introduction
-
-[./symbols.py](./symbols.py) is copied from
-https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py
-
-[./pypinyin-local.dict](./pypinyin-local.dict) is copied from
-https://github.com/UEhQZXI/vits_chinese/blob/master/misc/pypinyin-local.dict
diff --git a/egs/baker_zh/TTS/local/__init__.py b/egs/baker_zh/TTS/local/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/egs/baker_zh/TTS/local/compute_spectrogram_baker.py b/egs/baker_zh/TTS/local/compute_spectrogram_baker.py
deleted file mode 100755
index 1a15c7c0d..000000000
--- a/egs/baker_zh/TTS/local/compute_spectrogram_baker.py
+++ /dev/null
@@ -1,106 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
-#                                            Zengwei Yao)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This file computes spectrogram features of the baker_zh dataset.
-It looks for manifests in the directory data/manifests.
-
-The generated spectrogram features are saved in data/spectrogram.
-"""
-
-import logging
-import os
-from pathlib import Path
-
-import torch
-from lhotse import (
-    CutSet,
-    LilcomChunkyWriter,
-    Spectrogram,
-    SpectrogramConfig,
-    load_manifest,
-)
-from lhotse.audio import RecordingSet
-from lhotse.supervision import SupervisionSet
-
-from icefall.utils import get_executor
-
-# Torch's multithreaded behavior needs to be disabled or
-# it wastes a lot of CPU and slows things down.
-# Do this outside of main() in case it needs to take effect
-# even when we are not invoking the main (e.g. when spawning subprocesses).
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
-
-def compute_spectrogram_baker_zh():
-    src_dir = Path("data/manifests")
-    output_dir = Path("data/spectrogram")
-    num_jobs = min(4, os.cpu_count())
-
-    sampling_rate = 48000
-    frame_length = 1024 / sampling_rate  # (in seconds)
-    frame_shift = 256 / sampling_rate  # (in seconds)
-    use_fft_mag = True
-
-    prefix = "baker_zh"
-    suffix = "jsonl.gz"
-    partition = "all"
-
-    recordings = load_manifest(
-        src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
-    )
-    supervisions = load_manifest(
-        src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
-    )
-
-    config = SpectrogramConfig(
-        sampling_rate=sampling_rate,
-        frame_length=frame_length,
-        frame_shift=frame_shift,
-        use_fft_mag=use_fft_mag,
-    )
-    extractor = Spectrogram(config)
-
-    with get_executor() as ex:  # Initialize the executor only once.
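# For intuition about the configuration above (an added note, not part of
# the original script): each analysis window covers 1024 / 48000 ≈ 21.3 ms
# of audio and consecutive frames are 256 / 48000 ≈ 5.3 ms apart, i.e.
# 48000 / 256 = 187.5 frames per second of speech. With a 1024-point FFT,
# each frame has 1024 // 2 + 1 = 513 linear-frequency bins, which matches
# the "feature_dim": 513 used in vits/train.py below.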
- cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{cuts_filename} already exists - skipping.") - return - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=recordings, supervisions=supervisions - ) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_spectrogram_baker_zh() diff --git a/egs/baker_zh/TTS/local/pinyin_dict.py b/egs/baker_zh/TTS/local/pinyin_dict.py deleted file mode 100644 index 950fb39fc..000000000 --- a/egs/baker_zh/TTS/local/pinyin_dict.py +++ /dev/null @@ -1,421 +0,0 @@ -# This dict is copied from -# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py -pinyin_dict = { - "a": ("^", "a"), - "ai": ("^", "ai"), - "an": ("^", "an"), - "ang": ("^", "ang"), - "ao": ("^", "ao"), - "ba": ("b", "a"), - "bai": ("b", "ai"), - "ban": ("b", "an"), - "bang": ("b", "ang"), - "bao": ("b", "ao"), - "be": ("b", "e"), - "bei": ("b", "ei"), - "ben": ("b", "en"), - "beng": ("b", "eng"), - "bi": ("b", "i"), - "bian": ("b", "ian"), - "biao": ("b", "iao"), - "bie": ("b", "ie"), - "bin": ("b", "in"), - "bing": ("b", "ing"), - "bo": ("b", "o"), - "bu": ("b", "u"), - "ca": ("c", "a"), - "cai": ("c", "ai"), - "can": ("c", "an"), - "cang": ("c", "ang"), - "cao": ("c", "ao"), - "ce": ("c", "e"), - "cen": ("c", "en"), - "ceng": ("c", "eng"), - "cha": ("ch", "a"), - "chai": ("ch", "ai"), - "chan": ("ch", "an"), - "chang": ("ch", "ang"), - "chao": ("ch", "ao"), - "che": ("ch", "e"), - "chen": ("ch", "en"), - "cheng": ("ch", "eng"), - "chi": ("ch", "iii"), - "chong": ("ch", "ong"), - "chou": ("ch", "ou"), - "chu": ("ch", "u"), - "chua": ("ch", "ua"), - "chuai": ("ch", "uai"), - "chuan": ("ch", "uan"), - "chuang": ("ch", "uang"), - "chui": ("ch", "uei"), - "chun": ("ch", "uen"), - "chuo": ("ch", "uo"), - "ci": ("c", "ii"), - "cong": ("c", "ong"), - "cou": ("c", "ou"), - "cu": ("c", "u"), - "cuan": ("c", "uan"), - "cui": ("c", "uei"), - "cun": ("c", "uen"), - "cuo": ("c", "uo"), - "da": ("d", "a"), - "dai": ("d", "ai"), - "dan": ("d", "an"), - "dang": ("d", "ang"), - "dao": ("d", "ao"), - "de": ("d", "e"), - "dei": ("d", "ei"), - "den": ("d", "en"), - "deng": ("d", "eng"), - "di": ("d", "i"), - "dia": ("d", "ia"), - "dian": ("d", "ian"), - "diao": ("d", "iao"), - "die": ("d", "ie"), - "ding": ("d", "ing"), - "diu": ("d", "iou"), - "dong": ("d", "ong"), - "dou": ("d", "ou"), - "du": ("d", "u"), - "duan": ("d", "uan"), - "dui": ("d", "uei"), - "dun": ("d", "uen"), - "duo": ("d", "uo"), - "e": ("^", "e"), - "ei": ("^", "ei"), - "en": ("^", "en"), - "ng": ("^", "en"), - "eng": ("^", "eng"), - "er": ("^", "er"), - "fa": ("f", "a"), - "fan": ("f", "an"), - "fang": ("f", "ang"), - "fei": ("f", "ei"), - "fen": ("f", "en"), - "feng": ("f", "eng"), - "fo": ("f", "o"), - "fou": ("f", "ou"), - "fu": ("f", "u"), - "ga": ("g", "a"), - "gai": ("g", "ai"), - "gan": ("g", "an"), - "gang": ("g", "ang"), - "gao": ("g", "ao"), - "ge": ("g", "e"), - "gei": ("g", "ei"), - "gen": ("g", "en"), - "geng": ("g", "eng"), - "gong": ("g", "ong"), - "gou": ("g", 
"ou"), - "gu": ("g", "u"), - "gua": ("g", "ua"), - "guai": ("g", "uai"), - "guan": ("g", "uan"), - "guang": ("g", "uang"), - "gui": ("g", "uei"), - "gun": ("g", "uen"), - "guo": ("g", "uo"), - "ha": ("h", "a"), - "hai": ("h", "ai"), - "han": ("h", "an"), - "hang": ("h", "ang"), - "hao": ("h", "ao"), - "he": ("h", "e"), - "hei": ("h", "ei"), - "hen": ("h", "en"), - "heng": ("h", "eng"), - "hong": ("h", "ong"), - "hou": ("h", "ou"), - "hu": ("h", "u"), - "hua": ("h", "ua"), - "huai": ("h", "uai"), - "huan": ("h", "uan"), - "huang": ("h", "uang"), - "hui": ("h", "uei"), - "hun": ("h", "uen"), - "huo": ("h", "uo"), - "ji": ("j", "i"), - "jia": ("j", "ia"), - "jian": ("j", "ian"), - "jiang": ("j", "iang"), - "jiao": ("j", "iao"), - "jie": ("j", "ie"), - "jin": ("j", "in"), - "jing": ("j", "ing"), - "jiong": ("j", "iong"), - "jiu": ("j", "iou"), - "ju": ("j", "v"), - "juan": ("j", "van"), - "jue": ("j", "ve"), - "jun": ("j", "vn"), - "ka": ("k", "a"), - "kai": ("k", "ai"), - "kan": ("k", "an"), - "kang": ("k", "ang"), - "kao": ("k", "ao"), - "ke": ("k", "e"), - "kei": ("k", "ei"), - "ken": ("k", "en"), - "keng": ("k", "eng"), - "kong": ("k", "ong"), - "kou": ("k", "ou"), - "ku": ("k", "u"), - "kua": ("k", "ua"), - "kuai": ("k", "uai"), - "kuan": ("k", "uan"), - "kuang": ("k", "uang"), - "kui": ("k", "uei"), - "kun": ("k", "uen"), - "kuo": ("k", "uo"), - "la": ("l", "a"), - "lai": ("l", "ai"), - "lan": ("l", "an"), - "lang": ("l", "ang"), - "lao": ("l", "ao"), - "le": ("l", "e"), - "lei": ("l", "ei"), - "leng": ("l", "eng"), - "li": ("l", "i"), - "lia": ("l", "ia"), - "lian": ("l", "ian"), - "liang": ("l", "iang"), - "liao": ("l", "iao"), - "lie": ("l", "ie"), - "lin": ("l", "in"), - "ling": ("l", "ing"), - "liu": ("l", "iou"), - "lo": ("l", "o"), - "long": ("l", "ong"), - "lou": ("l", "ou"), - "lu": ("l", "u"), - "lv": ("l", "v"), - "luan": ("l", "uan"), - "lve": ("l", "ve"), - "lue": ("l", "ve"), - "lun": ("l", "uen"), - "luo": ("l", "uo"), - "ma": ("m", "a"), - "mai": ("m", "ai"), - "man": ("m", "an"), - "mang": ("m", "ang"), - "mao": ("m", "ao"), - "me": ("m", "e"), - "mei": ("m", "ei"), - "men": ("m", "en"), - "meng": ("m", "eng"), - "mi": ("m", "i"), - "mian": ("m", "ian"), - "miao": ("m", "iao"), - "mie": ("m", "ie"), - "min": ("m", "in"), - "ming": ("m", "ing"), - "miu": ("m", "iou"), - "mo": ("m", "o"), - "mou": ("m", "ou"), - "mu": ("m", "u"), - "na": ("n", "a"), - "nai": ("n", "ai"), - "nan": ("n", "an"), - "nang": ("n", "ang"), - "nao": ("n", "ao"), - "ne": ("n", "e"), - "nei": ("n", "ei"), - "nen": ("n", "en"), - "neng": ("n", "eng"), - "ni": ("n", "i"), - "nia": ("n", "ia"), - "nian": ("n", "ian"), - "niang": ("n", "iang"), - "niao": ("n", "iao"), - "nie": ("n", "ie"), - "nin": ("n", "in"), - "ning": ("n", "ing"), - "niu": ("n", "iou"), - "nong": ("n", "ong"), - "nou": ("n", "ou"), - "nu": ("n", "u"), - "nv": ("n", "v"), - "nuan": ("n", "uan"), - "nve": ("n", "ve"), - "nue": ("n", "ve"), - "nuo": ("n", "uo"), - "o": ("^", "o"), - "ou": ("^", "ou"), - "pa": ("p", "a"), - "pai": ("p", "ai"), - "pan": ("p", "an"), - "pang": ("p", "ang"), - "pao": ("p", "ao"), - "pe": ("p", "e"), - "pei": ("p", "ei"), - "pen": ("p", "en"), - "peng": ("p", "eng"), - "pi": ("p", "i"), - "pian": ("p", "ian"), - "piao": ("p", "iao"), - "pie": ("p", "ie"), - "pin": ("p", "in"), - "ping": ("p", "ing"), - "po": ("p", "o"), - "pou": ("p", "ou"), - "pu": ("p", "u"), - "qi": ("q", "i"), - "qia": ("q", "ia"), - "qian": ("q", "ian"), - "qiang": ("q", "iang"), - "qiao": ("q", "iao"), - "qie": ("q", "ie"), - "qin": 
("q", "in"), - "qing": ("q", "ing"), - "qiong": ("q", "iong"), - "qiu": ("q", "iou"), - "qu": ("q", "v"), - "quan": ("q", "van"), - "que": ("q", "ve"), - "qun": ("q", "vn"), - "ran": ("r", "an"), - "rang": ("r", "ang"), - "rao": ("r", "ao"), - "re": ("r", "e"), - "ren": ("r", "en"), - "reng": ("r", "eng"), - "ri": ("r", "iii"), - "rong": ("r", "ong"), - "rou": ("r", "ou"), - "ru": ("r", "u"), - "rua": ("r", "ua"), - "ruan": ("r", "uan"), - "rui": ("r", "uei"), - "run": ("r", "uen"), - "ruo": ("r", "uo"), - "sa": ("s", "a"), - "sai": ("s", "ai"), - "san": ("s", "an"), - "sang": ("s", "ang"), - "sao": ("s", "ao"), - "se": ("s", "e"), - "sen": ("s", "en"), - "seng": ("s", "eng"), - "sha": ("sh", "a"), - "shai": ("sh", "ai"), - "shan": ("sh", "an"), - "shang": ("sh", "ang"), - "shao": ("sh", "ao"), - "she": ("sh", "e"), - "shei": ("sh", "ei"), - "shen": ("sh", "en"), - "sheng": ("sh", "eng"), - "shi": ("sh", "iii"), - "shou": ("sh", "ou"), - "shu": ("sh", "u"), - "shua": ("sh", "ua"), - "shuai": ("sh", "uai"), - "shuan": ("sh", "uan"), - "shuang": ("sh", "uang"), - "shui": ("sh", "uei"), - "shun": ("sh", "uen"), - "shuo": ("sh", "uo"), - "si": ("s", "ii"), - "song": ("s", "ong"), - "sou": ("s", "ou"), - "su": ("s", "u"), - "suan": ("s", "uan"), - "sui": ("s", "uei"), - "sun": ("s", "uen"), - "suo": ("s", "uo"), - "ta": ("t", "a"), - "tai": ("t", "ai"), - "tan": ("t", "an"), - "tang": ("t", "ang"), - "tao": ("t", "ao"), - "te": ("t", "e"), - "tei": ("t", "ei"), - "teng": ("t", "eng"), - "ti": ("t", "i"), - "tian": ("t", "ian"), - "tiao": ("t", "iao"), - "tie": ("t", "ie"), - "ting": ("t", "ing"), - "tong": ("t", "ong"), - "tou": ("t", "ou"), - "tu": ("t", "u"), - "tuan": ("t", "uan"), - "tui": ("t", "uei"), - "tun": ("t", "uen"), - "tuo": ("t", "uo"), - "wa": ("^", "ua"), - "wai": ("^", "uai"), - "wan": ("^", "uan"), - "wang": ("^", "uang"), - "wei": ("^", "uei"), - "wen": ("^", "uen"), - "weng": ("^", "ueng"), - "wo": ("^", "uo"), - "wu": ("^", "u"), - "xi": ("x", "i"), - "xia": ("x", "ia"), - "xian": ("x", "ian"), - "xiang": ("x", "iang"), - "xiao": ("x", "iao"), - "xie": ("x", "ie"), - "xin": ("x", "in"), - "xing": ("x", "ing"), - "xiong": ("x", "iong"), - "xiu": ("x", "iou"), - "xu": ("x", "v"), - "xuan": ("x", "van"), - "xue": ("x", "ve"), - "xun": ("x", "vn"), - "ya": ("^", "ia"), - "yan": ("^", "ian"), - "yang": ("^", "iang"), - "yao": ("^", "iao"), - "ye": ("^", "ie"), - "yi": ("^", "i"), - "yin": ("^", "in"), - "ying": ("^", "ing"), - "yo": ("^", "iou"), - "yong": ("^", "iong"), - "you": ("^", "iou"), - "yu": ("^", "v"), - "yuan": ("^", "van"), - "yue": ("^", "ve"), - "yun": ("^", "vn"), - "za": ("z", "a"), - "zai": ("z", "ai"), - "zan": ("z", "an"), - "zang": ("z", "ang"), - "zao": ("z", "ao"), - "ze": ("z", "e"), - "zei": ("z", "ei"), - "zen": ("z", "en"), - "zeng": ("z", "eng"), - "zha": ("zh", "a"), - "zhai": ("zh", "ai"), - "zhan": ("zh", "an"), - "zhang": ("zh", "ang"), - "zhao": ("zh", "ao"), - "zhe": ("zh", "e"), - "zhei": ("zh", "ei"), - "zhen": ("zh", "en"), - "zheng": ("zh", "eng"), - "zhi": ("zh", "iii"), - "zhong": ("zh", "ong"), - "zhou": ("zh", "ou"), - "zhu": ("zh", "u"), - "zhua": ("zh", "ua"), - "zhuai": ("zh", "uai"), - "zhuan": ("zh", "uan"), - "zhuang": ("zh", "uang"), - "zhui": ("zh", "uei"), - "zhun": ("zh", "uen"), - "zhuo": ("zh", "uo"), - "zi": ("z", "ii"), - "zong": ("z", "ong"), - "zou": ("z", "ou"), - "zu": ("z", "u"), - "zuan": ("z", "uan"), - "zui": ("z", "uei"), - "zun": ("z", "uen"), - "zuo": ("z", "uo"), -} diff --git 
a/egs/baker_zh/TTS/local/prepare_token_file.py b/egs/baker_zh/TTS/local/prepare_token_file.py
deleted file mode 100755
index d90910ab0..000000000
--- a/egs/baker_zh/TTS/local/prepare_token_file.py
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This file generates the file that maps tokens to IDs.
-"""
-
-import argparse
-from pathlib import Path
-
-from symbols import symbols
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "--tokens",
-        type=Path,
-        default=Path("data/tokens.txt"),
-        help="Path to the dict that maps the text tokens to IDs",
-    )
-
-    return parser.parse_args()
-
-
-def main():
-    args = get_args()
-    tokens = Path(args.tokens)
-
-    with open(tokens, "w", encoding="utf-8") as f:
-        for token_id, token in enumerate(symbols):
-            f.write(f"{token} {token_id}\n")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/egs/baker_zh/TTS/local/prepare_tokens_baker_zh.py b/egs/baker_zh/TTS/local/prepare_tokens_baker_zh.py
deleted file mode 100755
index 0b27fd1e9..000000000
--- a/egs/baker_zh/TTS/local/prepare_tokens_baker_zh.py
+++ /dev/null
@@ -1,59 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-"""
-This file reads the texts in the given manifest and saves the new cuts with tokens.
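For example (an illustrative sketch using the tokenizer described below): a cut
whose normalized_text is "你好。" should end up with
cut.tokens = ["sil", "n", "i3", "#0", "h", "ao3", "#0", "sil", "eos"].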
-""" - -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest - -from tokenizer import Tokenizer - - -def prepare_tokens_baker_zh(): - output_dir = Path("data/spectrogram") - prefix = "baker_zh" - suffix = "jsonl.gz" - partition = "all" - - cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") - - tokenizer = Tokenizer() - - new_cuts = [] - i = 0 - for cut in cut_set: - # Each cut only contains one supervision - assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) - text = cut.supervisions[0].normalized_text - cut.tokens = tokenizer.text_to_tokens(text) - - new_cuts.append(cut) - - new_cut_set = CutSet.from_cuts(new_cuts) - new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - - prepare_tokens_baker_zh() diff --git a/egs/baker_zh/TTS/local/pypinyin-local.dict b/egs/baker_zh/TTS/local/pypinyin-local.dict deleted file mode 100644 index 5e386014c..000000000 --- a/egs/baker_zh/TTS/local/pypinyin-local.dict +++ /dev/null @@ -1,328 +0,0 @@ -姐姐 jie3 jie -宝宝 bao3 bao -哥哥 ge1 ge -妹妹 mei4 mei -弟弟 di4 di -妈妈 ma1 ma -开心哦 kai1 xin1 o -爸爸 ba4 ba -秘密哟 mi4 mi4 yo -哦 o -一年 yi4 nian2 -一夜 yi2 ye4 -一切 yi2 qie4 -一座 yi2 zuo4 -一下 yi2 xia4 -上一山 shang4 yi2 shan1 -下一山 xia4 yi2 shan1 -休息 xiu1 xi2 -东西 dong1 xi -上一届 shang4 yi2 jie4 -便宜 pian2 yi4 -加长 jia1 chang2 -单田芳 shan4 tian2 fang1 -帧 zhen1 -长时间 chang2 shi2 jian1 -长时 chang2 shi2 -识别 shi2 bie2 -生命中 sheng1 ming4 zhong1 -踏实 ta1 shi -嗯 en4 -溜达 liu1 da -少儿 shao4 er2 -爷爷 ye2 ye -不是 bu2 shi4 -一圈 yi1 quan1 -厜读一声 zui1 du2 yi4 sheng1 -一种 yi4 zhong3 -一簇簇 yi2 cu4 cu4 -一个 yi2 ge4 -一样 yi2 yang4 -一跩一跩 yi4 zhuai3 yi4 zhuai3 -一会儿 yi2 hui4 er -一幢 yi2 zhuang4 -挨了 ai2 le -熬菜 ao1 cai4 -扒鸡 pa2 ji1 -背枪 bei1 qiang1 -绷瓷儿 beng4 ci2 er2 -绷劲儿 beng3 jin4 er -绷着脸 beng3 zhe lian3 -藏医 zang4 yi1 -噌吰 cheng1 hong2 -差点儿 cha4 dian3 er -差失 cha1 shi1 -差误 cha1 wu4 -孱头 can4 tou -乘间 cheng2 jian4 -锄镰棘矜 chu2 lian2 ji2 qin2 -川藏 chuan1 zang4 -穿著 chuan1 zhuo2 -答讪 da1 shan4 -答言 da1 yan2 -大伯子 da4 bai3 zi -大夫 dai4 fu -弹冠 tan2 guan1 -当间 dang1 jian4 -当然咯 dang1 ran2 lo -点种 dian3 zhong3 -垛好 duo4 hao3 -发疟子 fa1 yao4 zi -饭熟了 fan4 shou2 le -附著 fu4 zhuo2 -复沓 fu4 ta4 -供稿 gong1 gao3 -供养 gong1 yang3 -骨朵 gu1 duo -骨碌 gu1 lu -果脯 guo3 fu3 -哈什玛 ha4 shi2 ma3 -海蜇 hai3 zhe2 -呵欠 he1 qian -河水汤汤 he2 shui3 shang1 shang1 -鹄立 hu2 li4 -鹄望 hu2 wang4 -混人 hun2 ren2 -混水 hun2 shui3 -鸡血 ji1 xie3 -缉鞋口 qi1 xie2 kou3 -亟来闻讯 qi4 lai2 wen2 xun4 -计量 ji4 liang2 -济水 ji3 shui3 -间杂 jian4 za2 -脚跐两只船 jiao3 ci3 liang3 zhi1 chuan2 -脚儿 jue2 er2 -口角 kou3 jiao3 -勒石 le4 shi2 -累进 lei3 jin4 -累累如丧家之犬 lei2 lei2 ru2 sang4 jia1 zhi1 quan3 -累年 lei3 nian2 -脸涨通红 lian3 zhang4 tong1 hong2 -踉锵 liang4 qiang1 -燎眉毛 liao3 mei2 mao2 -燎头发 liao3 tou2 fa4 -溜达 liu1 da -溜缝儿 liu4 feng4 er -馏口饭 liu4 kou3 fan4 -遛马 liu4 ma3 -遛鸟 liu4 niao3 -遛弯儿 liu4 wan1 er -楼枪机 lou1 qiang1 ji1 -搂钱 lou1 qian2 -鹿脯 lu4 fu3 -露头 lou4 tou2 -落魄 luo4 po4 -捋胡子 lv3 hu2 zi -绿地 lv4 di4 -麦垛 mai4 duo4 -没劲儿 mei2 jin4 er -闷棍 men4 gun4 -闷葫芦 men4 hu2 lu -闷头干 men1 tou2 gan4 -蒙古 meng3 gu3 -靡日不思 mi3 ri4 bu4 si1 -缪姓 miao4 xing4 -抹墙 mo4 qiang2 -抹下脸 ma1 xia4 lian3 -泥子 ni4 zi -拗不过 niu4 bu guo4 -排车 pai3 che1 -盘诘 pan2 jie2 -膀肿 pang1 zhong3 -炮干 bao1 gan1 -炮格 pao2 ge2 -碰钉子 peng4 ding1 zi -缥色 piao3 se4 -瀑河 bao4 he2 -蹊径 xi1 jing4 -前后相属 qian2 hou4 xiang1 zhu3 -翘尾巴 qiao4 wei3 ba -趄坡儿 qie4 po1 er -秦桧 qin2 hui4 -圈马 juan1 ma3 -雀盲眼 qiao3 mang2 yan3 -雀子 qiao1 zi -三年五载 san1 nian2 wu3 
zai3 -加载 jia1 zai3 -山大王 shan1 dai4 wang -苫屋草 shan4 wu1 cao3 -数数 shu3 shu4 -说客 shui4 ke4 -思量 si1 liang2 -伺侯 ci4 hou -踏实 ta1 shi -提溜 di1 liu -调拨 diao4 bo1 -帖子 tie3 zi -铜钿 tong2 tian2 -头昏脑涨 tou2 hun1 nao3 zhang4 -褪色 tui4 se4 -褪着手 tun4 zhe shou3 -圩子 wei2 zi -尾巴 wei3 ba -系好船只 xi4 hao3 chuan2 zhi1 -系好马匹 xi4 hao3 ma3 pi3 -杏脯 xing4 fu3 -姓单 xing4 shan4 -姓葛 xing4 ge3 -姓哈 xing4 ha3 -姓解 xing4 xie4 -姓秘 xing4 bi4 -姓宁 xing4 ning4 -旋风 xuan4 feng1 -旋根车轴 xuan4 gen1 che1 zhou2 -荨麻 qian2 ma2 -一幢楼房 yi1 zhuang4 lou2 fang2 -遗之千金 wei4 zhi1 qian1 jin1 -殷殷 yin3 yin3 -应招 ying4 zhao1 -用称约 yong4 cheng4 yao1 -约斤肉 yao1 jin1 rou4 -晕机 yun4 ji1 -熨贴 yu4 tie1 -咋办 za3 ban4 -咋呼 zha1 hu -仔兽 zi3 shou4 -扎彩 za1 cai3 -扎实 zha1 shi -扎腰带 za1 yao1 dai4 -轧朋友 ga2 peng2 you3 -爪子 zhua3 zi -折腾 zhe1 teng -着实 zhuo2 shi2 -着我旧时裳 zhuo2 wo3 jiu4 shi2 chang2 -枝蔓 zhi1 man4 -中鹄 zhong1 hu2 -中选 zhong4 xuan3 -猪圈 zhu1 juan4 -拽住不放 zhuai4 zhu4 bu4 fang4 -转悠 zhuan4 you -庄稼熟了 zhuang1 jia shou2 le -酌量 zhuo2 liang2 -罪行累累 zui4 xing2 lei3 lei3 -一手 yi4 shou3 -一去不复返 yi2 qu4 bu2 fu4 fan3 -一颗 yi4 ke1 -一件 yi2 jian4 -一斤 yi4 jin1 -一点 yi4 dian3 -一朵 yi4 duo3 -一声 yi4 sheng1 -一身 yi4 shen1 -不要 bu2 yao4 -一人 yi4 ren2 -一个 yi2 ge4 -一把 yi4 ba3 -一门 yi4 men2 -一門 yi4 men2 -一艘 yi4 sou1 -一片 yi2 pian4 -一篇 yi2 pian1 -一份 yi2 fen4 -好嗲 hao3 dia3 -随地 sui2 di4 -扁担长 bian3 dan4 chang3 -一堆 yi4 dui1 -不义 bu2 yi4 -放一放 fang4 yi2 fang4 -一米 yi4 mi3 -一顿 yi2 dun4 -一层楼 yi4 ceng2 lou2 -一条 yi4 tiao2 -一件 yi2 jian4 -一棵 yi4 ke1 -一小股 yi4 xiao3 gu3 -一拐一拐 yi4 guai3 yi4 guai3 -一根 yi4 gen1 -沆瀣一气 hang4 xie4 yi2 qi4 -一丝 yi4 si1 -一毫 yi4 hao2 -一樣 yi2 yang4 -处处 chu4 chu4 -一餐 yi4 can -永不 yong3 bu2 -一看 yi2 kan4 -一架 yi2 jia4 -送还 song4 huan2 -一见 yi2 jian4 -一座 yi2 zuo4 -一块 yi2 kuai4 -一天 yi4 tian1 -一只 yi4 zhi1 -一支 yi4 zhi1 -一字 yi2 zi4 -一句 yi2 ju4 -一张 yi4 zhang1 -一條 yi4 tiao2 -一场 yi4 chang3 -一粒 yi2 li4 -小俩口 xiao3 liang3 kou3 -一首 yi4 shou3 -一对 yi2 dui4 -一手 yi4 shou3 -又一村 you4 yi4 cun1 -一概而论 yi2 gai4 er2 lun4 -一峰峰 yi4 feng1 feng1 -不但 bu2 dan4 -一笑 yi2 xiao4 -挠痒痒 nao2 yang3 yang -不对 bu2 dui4 -拧开 ning3 kai1 -爱不释手 ai4 bu2 shi4 shou3 -一念 yi2 nian4 -夺得 duo2 de2 -一袭 yi4 xi2 -一定 yi2 ding4 -不慎 bu2 shen4 -剽窃 piao2 qie4 -一时 yi4 shi2 -撇开 pie3 kai1 -一祭 yi2 ji4 -发卡 fa4 qia3 -少不了 shao3 bu4 liao3 -千虑一失 qian1 lv4 yi4 shi1 -呛得 qiang4 de2 -切菜 qie1 cai4 -茄盒 qie2 he2 -不去 bu2 qu4 -一大圈 yi2 da4 quan1 -不再 bu2 zai4 -一群 yi4 qun2 -不必 bu2 bi4 -一些 yi4 xie1 -一路 yi2 lu4 -一股 yi4 gu3 -一到 yi2 dao4 -一拨 yi4 bo1 -一排 yi4 pai2 -一空 yi4 kong1 -吮吸着 shun3 xi1 zhe -不适合 bu2 shi4 he2 -一串串 yi2 chuan4 chuan4 -一提起 yi4 ti2 qi3 -一尘不染 yi4 chen2 bu4 ran3 -一生 yi4 sheng1 -一派 yi2 pai4 -不断 bu2 duan4 -一次 yi2 ci4 -不进步 bu2 jin4 bu4 -娃娃 wa2 wa -万户侯 wan4 hu4 hou2 -一方 yi4 fang1 -一番话 yi4 fan1 hua4 -一遍 yi2 bian4 -不计较 bu2 ji4 jiao4 -诇 xiong4 -一边 yi4 bian1 -一束 yi2 shu4 -一听到 yi4 ting1 dao4 -炸鸡 zha2 ji1 -乍暧还寒 zha4 ai4 huan2 han2 -我说诶 wo3 shuo1 ei1 -棒诶 bang4 ei1 -寒碜 han2 chen4 -应采儿 ying4 cai3 er2 -晕车 yun1 che1 -必应 bi4 ying4 -应援 ying4 yuan2 -应力 ying4 li4 \ No newline at end of file diff --git a/egs/baker_zh/TTS/local/symbols.py b/egs/baker_zh/TTS/local/symbols.py deleted file mode 100644 index 1e6878870..000000000 --- a/egs/baker_zh/TTS/local/symbols.py +++ /dev/null @@ -1,73 +0,0 @@ -# This file is copied from -# https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py -_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"] - -_initials = [ - "^", - "b", - "c", - "ch", - "d", - "f", - "g", - "h", - "j", - "k", - "l", - "m", - "n", - "p", - "q", - "r", - "s", - "sh", - "t", - "x", - "z", - "zh", -] - -_tones = ["1", "2", "3", "4", "5"] - -_finals = [ - "a", - "ai", - "an", - "ang", - 
"ao", - "e", - "ei", - "en", - "eng", - "er", - "i", - "ia", - "ian", - "iang", - "iao", - "ie", - "ii", - "iii", - "in", - "ing", - "iong", - "iou", - "o", - "ong", - "ou", - "u", - "ua", - "uai", - "uan", - "uang", - "uei", - "uen", - "ueng", - "uo", - "v", - "van", - "ve", - "vn", -] - -symbols = _pause + _initials + [i + j for i in _finals for j in _tones] diff --git a/egs/baker_zh/TTS/local/tokenizer.py b/egs/baker_zh/TTS/local/tokenizer.py deleted file mode 100644 index cbf6c9c77..000000000 --- a/egs/baker_zh/TTS/local/tokenizer.py +++ /dev/null @@ -1,137 +0,0 @@ -# This file is modified from -# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py - -import logging -from pathlib import Path -from typing import List - -# Note pinyin_dict is from ./pinyin_dict.py -from pinyin_dict import pinyin_dict -from pypinyin import Style -from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin -from pypinyin.converter import DefaultConverter -from pypinyin.core import Pinyin, load_phrases_dict - - -class _MyConverter(NeutralToneWith5Mixin, DefaultConverter): - pass - - -class Tokenizer: - def __init__(self, tokens: str = ""): - self._load_pinyin_dict() - self._pinyin_parser = Pinyin(_MyConverter()) - - if tokens != "": - self._load_tokens(tokens) - - def texts_to_token_ids(self, texts: List[str], **kwargs) -> List[List[int]]: - """ - Args: - texts: - A list of sentences. - kwargs: - Not used. It is for compatibility with other TTS recipes in icefall. - """ - tokens = [] - - for text in texts: - tokens.append(self.text_to_tokens(text)) - - return self.tokens_to_token_ids(tokens) - - def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]: - ans = [] - - for token_list in tokens: - token_ids = [] - for t in token_list: - if t not in self.token2id: - logging.warning(f"Skip OOV {t}") - continue - token_ids.append(self.token2id[t]) - ans.append(token_ids) - - return ans - - def text_to_tokens(self, text: str) -> List[str]: - # Convert "," to ["sp", "sil"] - # Convert "。" to ["sil"] - # append ["eos"] at the end of a sentence - phonemes = ["sil"] - pinyins = self._pinyin_parser.pinyin( - text, - style=Style.TONE3, - errors=lambda x: [[w] for w in x], - ) - - new_pinyin = [] - for p in pinyins: - p = p[0] - if p == ",": - new_pinyin.extend(["sp", "sil"]) - elif p == "。": - new_pinyin.append("sil") - else: - new_pinyin.append(p) - sub_phonemes = self._get_phoneme4pinyin(new_pinyin) - sub_phonemes.append("eos") - phonemes.extend(sub_phonemes) - return phonemes - - def _get_phoneme4pinyin(self, pinyins): - result = [] - for pinyin in pinyins: - if pinyin in ("sil", "sp"): - result.append(pinyin) - elif pinyin[:-1] in pinyin_dict: - tone = pinyin[-1] - a = pinyin[:-1] - a1, a2 = pinyin_dict[a] - # every word is appended with a #0 - result += [a1, a2 + tone, "#0"] - - return result - - def _load_pinyin_dict(self): - this_dir = Path(__file__).parent.resolve() - my_dict = {} - with open(f"{this_dir}/pypinyin-local.dict", "r", encoding="utf-8") as f: - content = f.readlines() - for line in content: - cuts = line.strip().split() - hanzi = cuts[0] - pinyin = cuts[1:] - my_dict[hanzi] = [[p] for p in pinyin] - - load_phrases_dict(my_dict) - - def _load_tokens(self, filename): - token2id: Dict[str, int] = {} - - with open(filename, "r", encoding="utf-8") as f: - for line in f.readlines(): - info = line.rstrip().split() - if len(info) == 1: - # case of space - token = " " - idx = int(info[0]) - else: - token, idx = info[0], int(info[1]) - - assert token not in token2id, token 
-
-                token2id[token] = idx
-
-        self.token2id = token2id
-        self.vocab_size = len(self.token2id)
-        self.pad_id = self.token2id["#0"]
-
-
-def main():
-    tokenizer = Tokenizer()
-    tokenizer.text_to_tokens("你好,好的。")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/egs/baker_zh/TTS/local/validate_manifest.py b/egs/baker_zh/TTS/local/validate_manifest.py
deleted file mode 120000
index b4d52ebca..000000000
--- a/egs/baker_zh/TTS/local/validate_manifest.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../ljspeech/TTS/local/validate_manifest.py
\ No newline at end of file
diff --git a/egs/baker_zh/TTS/prepare.sh b/egs/baker_zh/TTS/prepare.sh
deleted file mode 100755
index 6fa87fe43..000000000
--- a/egs/baker_zh/TTS/prepare.sh
+++ /dev/null
@@ -1,124 +0,0 @@
-#!/usr/bin/env bash
-
-# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
-export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
-
-set -eou pipefail
-
-stage=-1
-stop_stage=100
-
-dl_dir=$PWD/download
-
-. shared/parse_options.sh || exit 1
-
-# All files generated by this script are saved in "data".
-# You can safely remove "data" and rerun this script to regenerate it.
-mkdir -p data
-
-log() {
-  # This function is from espnet
-  local fname=${BASH_SOURCE[1]##*/}
-  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
-}
-
-log "dl_dir: $dl_dir"
-
-if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
-  log "Stage 0: build monotonic_align lib"
-  if [ ! -d vits/monotonic_align/build ]; then
-    cd vits/monotonic_align
-    python3 setup.py build_ext --inplace
-    cd ../../
-  else
-    log "monotonic_align lib already built"
-  fi
-fi
-
-if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-  log "Stage 1: Download data"
-
-  # The directory $dl_dir/BZNSYP will contain 3 sub directories:
-  #   - PhoneLabeling
-  #   - ProsodyLabeling
-  #   - Wave
-
-  # If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink
-  #
-  #   ln -sfv /path/to/BZNSYP $dl_dir/
-  #   touch $dl_dir/BZNSYP/.completed
-  #
-  if [ ! -d $dl_dir/BZNSYP ]; then
-    lhotse download baker-zh $dl_dir
-  fi
-fi
-
-if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Prepare baker-zh manifest"
-  # We assume that you have downloaded the baker corpus
-  # to $dl_dir/BZNSYP
-  mkdir -p data/manifests
-  if [ ! -e data/manifests/.baker.done ]; then
-    lhotse prepare baker-zh $dl_dir/BZNSYP data/manifests
-    touch data/manifests/.baker.done
-  fi
-fi
-
-if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Compute spectrogram for baker (may take 3 minutes)"
-  mkdir -p data/spectrogram
-  if [ ! -e data/spectrogram/.baker.done ]; then
-    ./local/compute_spectrogram_baker.py
-    touch data/spectrogram/.baker.done
-  fi
-
-  if [ ! -e data/spectrogram/.baker-validated.done ]; then
-    log "Validating data/spectrogram for baker"
-    python3 ./local/validate_manifest.py \
-      data/spectrogram/baker_zh_cuts_all.jsonl.gz
-    touch data/spectrogram/.baker-validated.done
-  fi
-fi
-
-if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Prepare tokens for baker-zh (may take 20 seconds)"
-  if [ ! -e data/spectrogram/.baker_zh_with_token.done ]; then
-
-    ./local/prepare_tokens_baker_zh.py
-
-    mv -v data/spectrogram/baker_zh_cuts_with_tokens_all.jsonl.gz \
-      data/spectrogram/baker_zh_cuts_all.jsonl.gz
-
-    touch data/spectrogram/.baker_zh_with_token.done
-  fi
-fi
-
-if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Split the baker-zh cuts into train, valid and test sets (may take 25 seconds)"
-  if [ !
-e data/spectrogram/.baker_zh_split.done ]; then - lhotse subset --last 600 \ - data/spectrogram/baker_zh_cuts_all.jsonl.gz \ - data/spectrogram/baker_zh_cuts_validtest.jsonl.gz - lhotse subset --first 100 \ - data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \ - data/spectrogram/baker_zh_cuts_valid.jsonl.gz - lhotse subset --last 500 \ - data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \ - data/spectrogram/baker_zh_cuts_test.jsonl.gz - - rm data/spectrogram/baker_zh_cuts_validtest.jsonl.gz - - n=$(( $(gunzip -c data/spectrogram/baker_zh_cuts_all.jsonl.gz | wc -l) - 600 )) - lhotse subset --first $n \ - data/spectrogram/baker_zh_cuts_all.jsonl.gz \ - data/spectrogram/baker_zh_cuts_train.jsonl.gz - touch data/spectrogram/.baker_zh_split.done - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Generate token file" - if [ ! -e data/tokens.txt ]; then - ./local/prepare_token_file.py --tokens data/tokens.txt - fi -fi diff --git a/egs/baker_zh/TTS/shared b/egs/baker_zh/TTS/shared deleted file mode 120000 index 4cbd91a7e..000000000 --- a/egs/baker_zh/TTS/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/duration_predictor.py b/egs/baker_zh/TTS/vits/duration_predictor.py deleted file mode 120000 index 9972b476f..000000000 --- a/egs/baker_zh/TTS/vits/duration_predictor.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/duration_predictor.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/export-onnx.py b/egs/baker_zh/TTS/vits/export-onnx.py deleted file mode 100755 index 11c8a9791..000000000 --- a/egs/baker_zh/TTS/vits/export-onnx.py +++ /dev/null @@ -1,414 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script exports a VITS model from PyTorch to ONNX. - -Export the model to ONNX: -./vits/export-onnx.py \ - --epoch 1000 \ - --exp-dir vits/exp \ - --tokens data/tokens.txt - -It will generate one file inside vits/exp: - - vits-epoch-1000.onnx - -See ./test_onnx.py for how to use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import onnx -import torch -import torch.nn as nn -from tokenizer import Tokenizer -from train import get_model, get_params - -from icefall.checkpoint import load_checkpoint - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=1000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. 
- """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="vits/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--model-type", - type=str, - default="high", - choices=["low", "medium", "high"], - help="""If not empty, valid values are: low, medium, high. - It controls the model size. low -> runs faster. - """, - ) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = str(value) - - onnx.save(model, filename) - - -class OnnxModel(nn.Module): - """A wrapper for VITS generator.""" - - def __init__(self, model: nn.Module): - """ - Args: - model: - A VITS generator. - frame_shift: - The frame shift in samples. - """ - super().__init__() - self.model = model - - def forward( - self, - tokens: torch.Tensor, - tokens_lens: torch.Tensor, - noise_scale: float = 0.667, - alpha: float = 1.0, - noise_scale_dur: float = 0.8, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of VITS.inference_batch - - Args: - tokens: - Input text token indexes (1, T_text) - tokens_lens: - Number of tokens of shape (1,) - noise_scale (float): - Noise scale parameter for flow. - noise_scale_dur (float): - Noise scale parameter for duration predictor. - alpha (float): - Alpha parameter to control the speed of generated speech. - - Returns: - Return a tuple containing: - - audio, generated wavform tensor, (B, T_wav) - """ - audio, _, _ = self.model.generator.inference( - text=tokens, - text_lengths=tokens_lens, - noise_scale=noise_scale, - noise_scale_dur=noise_scale_dur, - alpha=alpha, - ) - return audio - - -def export_model_onnx( - model: nn.Module, - model_filename: str, - vocab_size: int, - opset_version: int = 11, -) -> None: - """Export the given generator model to ONNX format. - The exported model has one input: - - - tokens, a tensor of shape (1, T_text); dtype is torch.int64 - - and it has one output: - - - audio, a tensor of shape (1, T'); dtype is torch.float32 - - Args: - model: - The VITS generator. - model_filename: - The filename to save the exported ONNX model. - vocab_size: - Number of tokens used in training. - opset_version: - The opset version to use. 
- """ - tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64) - tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) - noise_scale = torch.tensor([1], dtype=torch.float32) - noise_scale_dur = torch.tensor([1], dtype=torch.float32) - alpha = torch.tensor([1], dtype=torch.float32) - - torch.onnx.export( - model, - (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur), - model_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "tokens", - "tokens_lens", - "noise_scale", - "alpha", - "noise_scale_dur", - ], - output_names=["audio"], - dynamic_axes={ - "tokens": {0: "N", 1: "T"}, - "tokens_lens": {0: "N"}, - "audio": {0: "N", 1: "T"}, - }, - ) - - if model.model.spks is None: - num_speakers = 1 - else: - num_speakers = model.model.spks - - meta_data = { - "model_type": "vits", - "version": "1", - "model_author": "k2-fsa", - "comment": "icefall", # must be icefall for models from icefall - "language": "Chinese", - "n_speakers": num_speakers, - "sample_rate": model.model.sampling_rate, # Must match the real sample rate - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=model_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - - model.to("cpu") - model.eval() - - model = OnnxModel(model=model) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M") - - suffix = f"epoch-{params.epoch}" - - opset_version = 13 - - logging.info("Exporting encoder") - model_filename = params.exp_dir / f"vits-{suffix}.onnx" - export_model_onnx( - model, - model_filename, - params.vocab_size, - opset_version=opset_version, - ) - logging.info(f"Exported generator to {model_filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() - -""" -Supported languages. - -LJSpeech is using "en-us" from the second column. 
- -Pty Language Age/Gender VoiceName File Other Languages - 5 af --/M Afrikaans gmw/af - 5 am --/M Amharic sem/am - 5 an --/M Aragonese roa/an - 5 ar --/M Arabic sem/ar - 5 as --/M Assamese inc/as - 5 az --/M Azerbaijani trk/az - 5 ba --/M Bashkir trk/ba - 5 be --/M Belarusian zle/be - 5 bg --/M Bulgarian zls/bg - 5 bn --/M Bengali inc/bn - 5 bpy --/M Bishnupriya_Manipuri inc/bpy - 5 bs --/M Bosnian zls/bs - 5 ca --/M Catalan roa/ca - 5 chr-US-Qaaa-x-west --/M Cherokee_ iro/chr - 5 cmn --/M Chinese_(Mandarin,_latin_as_English) sit/cmn (zh-cmn 5)(zh 5) - 5 cmn-latn-pinyin --/M Chinese_(Mandarin,_latin_as_Pinyin) sit/cmn-Latn-pinyin (zh-cmn 5)(zh 5) - 5 cs --/M Czech zlw/cs - 5 cv --/M Chuvash trk/cv - 5 cy --/M Welsh cel/cy - 5 da --/M Danish gmq/da - 5 de --/M German gmw/de - 5 el --/M Greek grk/el - 5 en-029 --/M English_(Caribbean) gmw/en-029 (en 10) - 2 en-gb --/M English_(Great_Britain) gmw/en (en 2) - 5 en-gb-scotland --/M English_(Scotland) gmw/en-GB-scotland (en 4) - 5 en-gb-x-gbclan --/M English_(Lancaster) gmw/en-GB-x-gbclan (en-gb 3)(en 5) - 5 en-gb-x-gbcwmd --/M English_(West_Midlands) gmw/en-GB-x-gbcwmd (en-gb 9)(en 9) - 5 en-gb-x-rp --/M English_(Received_Pronunciation) gmw/en-GB-x-rp (en-gb 4)(en 5) - 2 en-us --/M English_(America) gmw/en-US (en 3) - 5 en-us-nyc --/M English_(America,_New_York_City) gmw/en-US-nyc - 5 eo --/M Esperanto art/eo - 5 es --/M Spanish_(Spain) roa/es - 5 es-419 --/M Spanish_(Latin_America) roa/es-419 (es-mx 6) - 5 et --/M Estonian urj/et - 5 eu --/M Basque eu - 5 fa --/M Persian ira/fa - 5 fa-latn --/M Persian_(Pinglish) ira/fa-Latn - 5 fi --/M Finnish urj/fi - 5 fr-be --/M French_(Belgium) roa/fr-BE (fr 8) - 5 fr-ch --/M French_(Switzerland) roa/fr-CH (fr 8) - 5 fr-fr --/M French_(France) roa/fr (fr 5) - 5 ga --/M Gaelic_(Irish) cel/ga - 5 gd --/M Gaelic_(Scottish) cel/gd - 5 gn --/M Guarani sai/gn - 5 grc --/M Greek_(Ancient) grk/grc - 5 gu --/M Gujarati inc/gu - 5 hak --/M Hakka_Chinese sit/hak - 5 haw --/M Hawaiian map/haw - 5 he --/M Hebrew sem/he - 5 hi --/M Hindi inc/hi - 5 hr --/M Croatian zls/hr (hbs 5) - 5 ht --/M Haitian_Creole roa/ht - 5 hu --/M Hungarian urj/hu - 5 hy --/M Armenian_(East_Armenia) ine/hy (hy-arevela 5) - 5 hyw --/M Armenian_(West_Armenia) ine/hyw (hy-arevmda 5)(hy 8) - 5 ia --/M Interlingua art/ia - 5 id --/M Indonesian poz/id - 5 io --/M Ido art/io - 5 is --/M Icelandic gmq/is - 5 it --/M Italian roa/it - 5 ja --/M Japanese jpx/ja - 5 jbo --/M Lojban art/jbo - 5 ka --/M Georgian ccs/ka - 5 kk --/M Kazakh trk/kk - 5 kl --/M Greenlandic esx/kl - 5 kn --/M Kannada dra/kn - 5 ko --/M Korean ko - 5 kok --/M Konkani inc/kok - 5 ku --/M Kurdish ira/ku - 5 ky --/M Kyrgyz trk/ky - 5 la --/M Latin itc/la - 5 lb --/M Luxembourgish gmw/lb - 5 lfn --/M Lingua_Franca_Nova art/lfn - 5 lt --/M Lithuanian bat/lt - 5 ltg --/M Latgalian bat/ltg - 5 lv --/M Latvian bat/lv - 5 mi --/M Māori poz/mi - 5 mk --/M Macedonian zls/mk - 5 ml --/M Malayalam dra/ml - 5 mr --/M Marathi inc/mr - 5 ms --/M Malay poz/ms - 5 mt --/M Maltese sem/mt - 5 mto --/M Totontepec_Mixe miz/mto - 5 my --/M Myanmar_(Burmese) sit/my - 5 nb --/M Norwegian_Bokmål gmq/nb (no 5) - 5 nci --/M Nahuatl_(Classical) azc/nci - 5 ne --/M Nepali inc/ne - 5 nl --/M Dutch gmw/nl - 5 nog --/M Nogai trk/nog - 5 om --/M Oromo cus/om - 5 or --/M Oriya inc/or - 5 pa --/M Punjabi inc/pa - 5 pap --/M Papiamento roa/pap - 5 piqd --/M Klingon art/piqd - 5 pl --/M Polish zlw/pl - 5 pt --/M Portuguese_(Portugal) roa/pt (pt-pt 5) - 5 pt-br --/M Portuguese_(Brazil) roa/pt-BR (pt 6) - 5 py --/M 
Pyash art/py - 5 qdb --/M Lang_Belta art/qdb - 5 qu --/M Quechua qu - 5 quc --/M K'iche' myn/quc - 5 qya --/M Quenya art/qya - 5 ro --/M Romanian roa/ro - 5 ru --/M Russian zle/ru - 5 ru-cl --/M Russian_(Classic) zle/ru-cl - 2 ru-lv --/M Russian_(Latvia) zle/ru-LV - 5 sd --/M Sindhi inc/sd - 5 shn --/M Shan_(Tai_Yai) tai/shn - 5 si --/M Sinhala inc/si - 5 sjn --/M Sindarin art/sjn - 5 sk --/M Slovak zlw/sk - 5 sl --/M Slovenian zls/sl - 5 smj --/M Lule_Saami urj/smj - 5 sq --/M Albanian ine/sq - 5 sr --/M Serbian zls/sr - 5 sv --/M Swedish gmq/sv - 5 sw --/M Swahili bnt/sw - 5 ta --/M Tamil dra/ta - 5 te --/M Telugu dra/te - 5 th --/M Thai tai/th - 5 tk --/M Turkmen trk/tk - 5 tn --/M Setswana bnt/tn - 5 tr --/M Turkish trk/tr - 5 tt --/M Tatar trk/tt - 5 ug --/M Uyghur trk/ug - 5 uk --/M Ukrainian zle/uk - 5 ur --/M Urdu inc/ur - 5 uz --/M Uzbek trk/uz - 5 vi --/M Vietnamese_(Northern) aav/vi - 5 vi-vn-x-central --/M Vietnamese_(Central) aav/vi-VN-x-central - 5 vi-vn-x-south --/M Vietnamese_(Southern) aav/vi-VN-x-south - 5 yue --/M Chinese_(Cantonese) sit/yue (zh-yue 5)(zh 8) - 5 yue --/M Chinese_(Cantonese,_latin_as_Jyutping) sit/yue-Latn-jyutping (zh-yue 5)(zh 8) -""" diff --git a/egs/baker_zh/TTS/vits/flow.py b/egs/baker_zh/TTS/vits/flow.py deleted file mode 120000 index e65d91ea7..000000000 --- a/egs/baker_zh/TTS/vits/flow.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/generate_lexicon.py b/egs/baker_zh/TTS/vits/generate_lexicon.py deleted file mode 100755 index 6d040ef53..000000000 --- a/egs/baker_zh/TTS/vits/generate_lexicon.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 - -from pypinyin import phrases_dict, pinyin_dict -from tokenizer import Tokenizer - - -def main(): - filename = "lexicon.txt" - tokens = "./data/tokens.txt" - tokenizer = Tokenizer(tokens) - - word_dict = pinyin_dict.pinyin_dict - phrases = phrases_dict.phrases_dict - - i = 0 - with open(filename, "w", encoding="utf-8") as f: - for key in word_dict: - if not (0x4E00 <= key <= 0x9FFF): - continue - - w = chr(key) - - # 1 to remove the initial sil - # :-1 to remove the final eos - tokens = tokenizer.text_to_tokens(w)[1:-1] - - tokens = " ".join(tokens) - f.write(f"{w} {tokens}\n") - - for key in phrases: - # 1 to remove the initial sil - # :-1 to remove the final eos - tokens = tokenizer.text_to_tokens(key)[1:-1] - tokens = " ".join(tokens) - f.write(f"{key} {tokens}\n") - - -if __name__ == "__main__": - main() diff --git a/egs/baker_zh/TTS/vits/generator.py b/egs/baker_zh/TTS/vits/generator.py deleted file mode 120000 index 611679bfa..000000000 --- a/egs/baker_zh/TTS/vits/generator.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/generator.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/hifigan.py b/egs/baker_zh/TTS/vits/hifigan.py deleted file mode 120000 index 5ac025de7..000000000 --- a/egs/baker_zh/TTS/vits/hifigan.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/hifigan.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/loss.py b/egs/baker_zh/TTS/vits/loss.py deleted file mode 120000 index 672e5ff68..000000000 --- a/egs/baker_zh/TTS/vits/loss.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/monotonic_align b/egs/baker_zh/TTS/vits/monotonic_align deleted file mode 120000 index 71934e7cc..000000000 --- a/egs/baker_zh/TTS/vits/monotonic_align +++ /dev/null @@ -1 +0,0 @@ 
-../../../ljspeech/TTS/vits/monotonic_align \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/pinyin_dict.py b/egs/baker_zh/TTS/vits/pinyin_dict.py deleted file mode 120000 index b8683bd2d..000000000 --- a/egs/baker_zh/TTS/vits/pinyin_dict.py +++ /dev/null @@ -1 +0,0 @@ -../local/pinyin_dict.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/posterior_encoder.py b/egs/baker_zh/TTS/vits/posterior_encoder.py deleted file mode 120000 index 41d64a3a6..000000000 --- a/egs/baker_zh/TTS/vits/posterior_encoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/posterior_encoder.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/pypinyin-local.dict b/egs/baker_zh/TTS/vits/pypinyin-local.dict deleted file mode 120000 index 5bc9b7728..000000000 --- a/egs/baker_zh/TTS/vits/pypinyin-local.dict +++ /dev/null @@ -1 +0,0 @@ -../local/pypinyin-local.dict \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/residual_coupling.py b/egs/baker_zh/TTS/vits/residual_coupling.py deleted file mode 120000 index f979adbf0..000000000 --- a/egs/baker_zh/TTS/vits/residual_coupling.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/residual_coupling.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/test_onnx.py b/egs/baker_zh/TTS/vits/test_onnx.py deleted file mode 100755 index 66c94270c..000000000 --- a/egs/baker_zh/TTS/vits/test_onnx.py +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
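Before running the test script below, it can help to sanity-check the inputs,
outputs, and metadata of the exported graph directly with onnxruntime. A
minimal sketch (the checkpoint name follows the export-onnx.py docstring
above):

import onnxruntime as ort

sess = ort.InferenceSession(
    "vits/exp/vits-epoch-1000.onnx",
    providers=["CPUExecutionProvider"],
)
for i in sess.get_inputs():
    # expected names from export-onnx.py:
    # tokens, tokens_lens, noise_scale, alpha, noise_scale_dur
    print(i.name, i.shape, i.type)
print(sess.get_modelmeta().custom_metadata_map)  # sample_rate, language, ...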
- -""" -This script is used to test the exported onnx model by vits/export-onnx.py - -Use the onnx model to generate a wav: -./vits/test_onnx.py \ - --model-filename vits/exp/vits-epoch-1000.onnx \ - --tokens data/tokens.txt -""" - - -import argparse -import logging - -import onnxruntime as ort -import torch -import torchaudio -from tokenizer import Tokenizer - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--model-filename", - type=str, - required=True, - help="Path to the onnx model.", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--text", - type=str, - default="Ask not what your country can do for you; ask what you can do for your country.", - help="Text to generate speech for", - ) - - parser.add_argument( - "--output-filename", - type=str, - default="test_onnx.wav", - help="Filename to save the generated wave file.", - ) - - return parser - - -class OnnxModel: - def __init__(self, model_filename: str): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.model = ort.InferenceSession( - model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") - - metadata = self.model.get_modelmeta().custom_metadata_map - self.sample_rate = int(metadata["sample_rate"]) - - def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: - """ - Args: - tokens: - A 1-D tensor of shape (1, T) - Returns: - A tensor of shape (1, T') - """ - noise_scale = torch.tensor([0.667], dtype=torch.float32) - noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) - alpha = torch.tensor([1.0], dtype=torch.float32) - - out = self.model.run( - [ - self.model.get_outputs()[0].name, - ], - { - self.model.get_inputs()[0].name: tokens.numpy(), - self.model.get_inputs()[1].name: tokens_lens.numpy(), - self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: alpha.numpy(), - self.model.get_inputs()[4].name: noise_scale_dur.numpy(), - }, - )[0] - return torch.from_numpy(out) - - -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - tokenizer = Tokenizer(args.tokens) - - logging.info("About to create onnx model") - model = OnnxModel(args.model_filename) - - text = args.text - tokens = tokenizer.texts_to_token_ids([text]) - tokens = torch.tensor(tokens) # (1, T) - tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) - audio = model(tokens, tokens_lens) # (1, T') - - output_filename = args.output_filename - torchaudio.save(output_filename, audio, sample_rate=model.sample_rate) - logging.info(f"Saved to {output_filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/baker_zh/TTS/vits/text_encoder.py b/egs/baker_zh/TTS/vits/text_encoder.py deleted file mode 120000 index 0efba277e..000000000 --- a/egs/baker_zh/TTS/vits/text_encoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/text_encoder.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/tokenizer.py b/egs/baker_zh/TTS/vits/tokenizer.py deleted file mode 120000 index 0368e07d3..000000000 --- 
a/egs/baker_zh/TTS/vits/tokenizer.py +++ /dev/null @@ -1 +0,0 @@ -../local/tokenizer.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/train.py b/egs/baker_zh/TTS/vits/train.py deleted file mode 100755 index 694129a89..000000000 --- a/egs/baker_zh/TTS/vits/train.py +++ /dev/null @@ -1,927 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import numpy as np -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from lhotse.cut import Cut -from lhotse.utils import fix_random_seed -from tokenizer import Tokenizer -from torch.cuda.amp import GradScaler, autocast -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import BakerZhSpeechTtsDataModule -from utils import MetricsTracker, plot_feature, save_checkpoint -from vits import VITS - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, setup_logger, str2bool - -LRSchedulerType = torch.optim.lr_scheduler._LRScheduler - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=1000, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="vits/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/tokens.txt", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--lr", type=float, default=2.0e-4, help="The base learning rate." 
-    )
-
-    parser.add_argument(
-        "--seed",
-        type=int,
-        default=42,
-        help="The seed for random generators intended for reproducibility",
-    )
-
-    parser.add_argument(
-        "--print-diagnostics",
-        type=str2bool,
-        default=False,
-        help="Accumulate stats on activations, print them and exit.",
-    )
-
-    parser.add_argument(
-        "--inf-check",
-        type=str2bool,
-        default=False,
-        help="Add hooks to check for infinite module outputs and gradients.",
-    )
-
-    parser.add_argument(
-        "--save-every-n",
-        type=int,
-        default=20,
-        help="""Save checkpoint after processing this number of epochs
-        periodically. We save checkpoint to exp-dir/ whenever
-        params.cur_epoch % save_every_n == 0. The checkpoint filename
-        has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
-        Since it will take around 1000 epochs, we suggest using a large
-        save_every_n to save disk space.
-        """,
-    )
-
-    parser.add_argument(
-        "--use-fp16",
-        type=str2bool,
-        default=False,
-        help="Whether to use half precision training.",
-    )
-
-    parser.add_argument(
-        "--model-type",
-        type=str,
-        default="high",
-        choices=["low", "medium", "high"],
-        help="""If not empty, valid values are: low, medium, high.
-        It controls the model size. low -> runs faster.
-        """,
-    )
-
-    return parser
-
-
-def get_params() -> AttributeDict:
-    """Return a dict containing training parameters.
-
-    All training related parameters that are not passed from the commandline
-    are saved in the variable `params`.
-
-    Commandline options are merged into `params` after they are parsed, so
-    you can also access them via `params`.
-
-    Explanation of options saved in `params`:
-
-    - best_train_loss: Best training loss so far. It is used to select
-                       the model that has the lowest training loss. It is
-                       updated during the training.
-
-    - best_valid_loss: Best validation loss so far. It is used to select
-                       the model that has the lowest validation loss. It is
-                       updated during the training.
-
-    - best_train_epoch: It is the epoch that has the best training loss.
-
-    - best_valid_epoch: It is the epoch that has the best validation loss.
-
-    - batch_idx_train: Used for writing statistics to tensorboard. It
-                       contains number of batches trained so far across
-                       epochs.
-
-    - log_interval: Print training loss if `batch_idx % log_interval` is 0
-
-    - valid_interval: Run validation if `batch_idx % valid_interval` is 0
-
-    - feature_dim: The model input dim. It has to match the one used
-                   in computing features.
-    """
-    params = AttributeDict(
-        {
-            # training params
-            "best_train_loss": float("inf"),
-            "best_valid_loss": float("inf"),
-            "best_train_epoch": -1,
-            "best_valid_epoch": -1,
-            "batch_idx_train": -1,
-            "log_interval": 50,
-            "valid_interval": 200,
-            "env_info": get_env_info(),
-            "sampling_rate": 48000,
-            "frame_shift": 256,
-            "frame_length": 1024,
-            "feature_dim": 513,  # 1024 // 2 + 1, 1024 is fft_length
-            "n_mels": 80,
-            "lambda_adv": 1.0,  # loss scaling coefficient for adversarial loss
-            "lambda_mel": 45.0,  # loss scaling coefficient for Mel loss
-            "lambda_feat_match": 2.0,  # loss scaling coefficient for feat match loss
-            "lambda_dur": 1.0,  # loss scaling coefficient for duration loss
-            "lambda_kl": 1.0,  # loss scaling coefficient for KL divergence loss
-        }
-    )
-
-    return params
-
-
-def load_checkpoint_if_available(
-    params: AttributeDict, model: nn.Module
-) -> Optional[Dict[str, Any]]:
-    """Load checkpoint from file.
-
-    If params.start_epoch is larger than 1, it will load the checkpoint from
-    `params.start_epoch - 1`.
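For example (a usage sketch): with the defaults above,

    ./vits/train.py --start-epoch 101 --exp-dir vits/exp

should resume training from the checkpoint vits/exp/epoch-100.pt.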
- - -def load_checkpoint_if_available( - params: AttributeDict, model: nn.Module -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading the state dict for `model` and `optimizer`, it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint(filename, model=model) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def get_model(params: AttributeDict) -> nn.Module: - mel_loss_params = { - "n_mels": params.n_mels, - "frame_length": params.frame_length, - "frame_shift": params.frame_shift, - } - model = VITS( - vocab_size=params.vocab_size, - feature_dim=params.feature_dim, - sampling_rate=params.sampling_rate, - model_type=params.model_type, - mel_loss_params=mel_loss_params, - lambda_adv=params.lambda_adv, - lambda_mel=params.lambda_mel, - lambda_feat_match=params.lambda_feat_match, - lambda_dur=params.lambda_dur, - lambda_kl=params.lambda_kl, - ) - return model - - -def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): - """Parse batch data.""" - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - tokens = batch["tokens"] - - tokens = tokenizer.tokens_to_token_ids(tokens) - tokens = k2.RaggedTensor(tokens) - row_splits = tokens.shape.row_splits(1) - tokens_lens = row_splits[1:] - row_splits[:-1] - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) - # pad the ragged token ids into a tensor of shape (B, T) - tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) - - return audio, audio_lens, features, features_lens, tokens, tokens_lens - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - optimizer_g: Optimizer, - optimizer_d: Optimizer, - scheduler_g: LRSchedulerType, - scheduler_d: LRSchedulerType, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - tokenizer: - Used to convert text to phonemes. - optimizer_g: - The optimizer for the generator. - optimizer_d: - The optimizer for the discriminator. - scheduler_g: - The learning rate scheduler for the generator; we call step() once per epoch. - scheduler_d: - The learning rate scheduler for the discriminator; we call step() once per epoch. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mixed precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of processes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the process in DDP training. If no DDP is used, it should - be set to 0.
- """ - model.train() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used to track the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - - def save_bad_model(suffix: str = ""): - save_checkpoint( - filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", - model=model, - params=params, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=0, - ) - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - - batch_size = len(batch["tokens"]) - audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( - batch, tokenizer, device - ) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - try: - with autocast(enabled=params.use_fp16): - # forward discriminator - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=False, - ) - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - # update discriminator - optimizer_d.zero_grad() - scaler.scale(loss_d).backward() - scaler.step(optimizer_d) - - with autocast(enabled=params.use_fp16): - # forward generator - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - return_sample=params.batch_idx_train % params.log_interval == 0, - ) - for k, v in stats_g.items(): - if "returned_sample" not in k: - loss_info[k] = v * batch_size - # update generator - optimizer_g.zero_grad() - scaler.scale(loss_g).backward() - scaler.step(optimizer_g) - scaler.update() - - # summary stats - tot_loss = tot_loss + loss_info - except: # noqa - save_bad_model() - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if params.batch_idx_train % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. 
- cur_grad_scale = scaler._scale.item() - - if cur_grad_scale < 8.0 or ( - cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 - ): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - if not saved_bad_model: - save_bad_model(suffix="-first-warning") - saved_bad_model = True - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - save_bad_model() - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if params.batch_idx_train % params.log_interval == 0: - cur_lr_g = max(scheduler_g.get_last_lr()) - cur_lr_d = max(scheduler_d.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " - f"loss[{loss_info}], tot_loss[{tot_loss}], " - f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate_g", cur_lr_g, params.batch_idx_train - ) - tb_writer.add_scalar( - "train/learning_rate_d", cur_lr_d, params.batch_idx_train - ) - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - if "returned_sample" in stats_g: - speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] - tb_writer.add_audio( - "train/speech_hat_", - speech_hat_, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/speech_", - speech_, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_image( - "train/mel_hat_", - plot_feature(mel_hat_), - params.batch_idx_train, - dataformats="HWC", - ) - tb_writer.add_image( - "train/mel_", - plot_feature(mel_), - params.batch_idx_train, - dataformats="HWC", - ) - - if ( - params.batch_idx_train % params.valid_interval == 0 - and not params.print_diagnostics - ): - logging.info("Computing validation loss") - valid_info, (speech_hat, speech) = compute_validation_loss( - params=params, - model=model, - tokenizer=tokenizer, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - tb_writer.add_audio( - "train/valid_speech_hat", - speech_hat, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/valid_speech", - speech, - params.batch_idx_train, - params.sampling_rate, - ) - - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - tokenizer: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, - rank: int = 0, -) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: - """Run the validation process.""" - model.eval() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used
to summarize the stats over iterations - tot_loss = MetricsTracker() - returned_sample = None - - with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): - batch_size = len(batch["tokens"]) - ( - audio, - audio_lens, - features, - features_lens, - tokens, - tokens_lens, - ) = prepare_input(batch, tokenizer, device) - - loss_info = MetricsTracker() - loss_info["samples"] = batch_size - - # forward discriminator - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=False, - ) - assert loss_d.requires_grad is False - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - - # forward generator - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - ) - assert loss_g.requires_grad is False - for k, v in stats_g.items(): - loss_info[k] = v * batch_size - - # summary stats - tot_loss = tot_loss + loss_info - - # run inference on the first batch: - if batch_idx == 0 and rank == 0: - inner_model = model.module if isinstance(model, DDP) else model - audio_pred, _, duration = inner_model.inference( - text=tokens[0, : tokens_lens[0].item()] - ) - audio_pred = audio_pred.data.cpu().numpy() - audio_len_pred = ( - (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() - ) - assert audio_len_pred == len(audio_pred), ( - audio_len_pred, - len(audio_pred), - ) - audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() - returned_sample = (audio_pred, audio_gt) - - if world_size > 1: - tot_loss.reduce(device) - - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss, returned_sample - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - tokenizer: Tokenizer, - optimizer_g: torch.optim.Optimizer, - optimizer_d: torch.optim.Optimizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input( - batch, tokenizer, device - ) - try: - # for discriminator - with autocast(enabled=params.use_fp16): - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=False, - ) - optimizer_d.zero_grad() - loss_d.backward() - # for generator - with autocast(enabled=params.use_fp16): - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - ) - optimizer_g.zero_grad() - loss_g.backward() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting.
We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - generator = model.generator - discriminator = model.discriminator - - num_param_g = sum([p.numel() for p in generator.parameters()]) - logging.info(f"Number of parameters in generator: {num_param_g}") - num_param_d = sum([p.numel() for p in discriminator.parameters()]) - logging.info(f"Number of parameters in discriminator: {num_param_d}") - logging.info(f"Total number of parameters: {num_param_g + num_param_d}") - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - optimizer_g = torch.optim.AdamW( - generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 - ) - optimizer_d = torch.optim.AdamW( - discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 - ) - - scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) - scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) - - if checkpoints is not None: - # load state_dict for optimizers - if "optimizer_g" in checkpoints: - logging.info("Loading optimizer_g state dict") - optimizer_g.load_state_dict(checkpoints["optimizer_g"]) - if "optimizer_d" in checkpoints: - logging.info("Loading optimizer_d state dict") - optimizer_d.load_state_dict(checkpoints["optimizer_d"]) - - # load state_dict for schedulers - if "scheduler_g" in checkpoints: - logging.info("Loading scheduler_g state dict") - scheduler_g.load_state_dict(checkpoints["scheduler_g"]) - if "scheduler_d" in checkpoints: - logging.info("Loading scheduler_d state dict") - scheduler_d.load_state_dict(checkpoints["scheduler_d"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - baker_zh = BakerZhSpeechTtsDataModule(args) - - train_cuts = baker_zh.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # You should use 
../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 20.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) - return False - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_dl = baker_zh.train_dataloaders(train_cuts) - - valid_cuts = baker_zh.valid_cuts() - valid_dl = baker_zh.valid_dataloaders(valid_cuts) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - tokenizer=tokenizer, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - logging.info(f"Start epoch {epoch}") - - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - params.cur_epoch = epoch - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - train_one_epoch( - params=params, - model=model, - tokenizer=tokenizer, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - if epoch % params.save_every_n == 0 or epoch == params.num_epochs: - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - # step per epoch - scheduler_g.step() - scheduler_d.step() - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - BakerZhSpeechTtsDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/baker_zh/TTS/vits/transform.py b/egs/baker_zh/TTS/vits/transform.py deleted file mode 120000 index 962647408..000000000 --- a/egs/baker_zh/TTS/vits/transform.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/tts_datamodule.py b/egs/baker_zh/TTS/vits/tts_datamodule.py deleted file mode 100644 index 96c542277..000000000 --- a/egs/baker_zh/TTS/vits/tts_datamodule.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022-2023 Xiaomi Corporation (Authors: 
Mingshuang Luo, - # Zengwei Yao) - # - # See ../../../../LICENSE for clarification regarding multiple authors - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. - # You may obtain a copy of the License at - # - # http://www.apache.org/licenses/LICENSE-2.0 - # - # Unless required by applicable law or agreed to in writing, software - # distributed under the License is distributed on an "AS IS" BASIS, - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - # See the License for the specific language governing permissions and - # limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, Optional - -import torch -from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutConcatenate, - CutMix, - DynamicBucketingSampler, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, - SpeechSynthesisDataset, -) -from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples - AudioSamples, - OnTheFlyFeatures, -) -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class BakerZhSpeechTtsDataModule: - """ - DataModule for TTS experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in TTS - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in TTS tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - self.sampling_rate = 48000 - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="TTS data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSets -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/spectrogram"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler " - "(you might want to increase it for larger datasets).", - ) - - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction.
Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=False, - help="When enabled, each batch will have the " - "field: batch['cut'] with the cuts that " - "were used to construct it.", - ) - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--input-strategy", - type=str, - default="PrecomputedFeatures", - help="AudioSamples or PrecomputedFeatures", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to create train dataset") - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - sampling_rate = self.sampling_rate - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - train = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. 
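# Drawing the base seed from torch's global generator (seeded earlier via
# fix_random_seed) keeps dataloading reproducible across runs, while
# _SeedWorkers gives every dataloader worker the distinct seed
# `seed + worker_id`.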
- seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - sampling_rate = self.sampling_rate - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - else: - validate = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - shuffle=False, - ) - logging.info("About to create valid dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - if self.args.on_the_fly_feats: - sampling_rate = self.sampling_rate - config = SpectrogramConfig( - sampling_rate=sampling_rate, - frame_length=1024 / sampling_rate, # (in second), - frame_shift=256 / sampling_rate, # (in second) - use_fft_mag=True, - ) - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), - return_cuts=self.args.return_cuts, - ) - else: - test = SpeechSynthesisDataset( - return_text=False, - return_tokens=True, - feature_input_strategy=eval(self.args.input_strategy)(), - return_cuts=self.args.return_cuts, - ) - test_sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - num_buckets=self.args.num_buckets, - shuffle=False, - ) - logging.info("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=test_sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy( - self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz" - ) - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get validation cuts") - return load_manifest_lazy( - self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy( - self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz" - ) diff --git a/egs/baker_zh/TTS/vits/utils.py b/egs/baker_zh/TTS/vits/utils.py deleted file mode 120000 index 085e764b4..000000000 --- a/egs/baker_zh/TTS/vits/utils.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/vits.py b/egs/baker_zh/TTS/vits/vits.py deleted file mode 120000 index 1f58cf6fe..000000000 --- a/egs/baker_zh/TTS/vits/vits.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/vits.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/vits/wavenet.py 
b/egs/baker_zh/TTS/vits/wavenet.py deleted file mode 120000 index 28f0a78ee..000000000 --- a/egs/baker_zh/TTS/vits/wavenet.py +++ /dev/null @@ -1 +0,0 @@ -../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file
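A minimal sketch of how the datamodule in the deleted tts_datamodule.py was driven by train.py; it assumes the data preparation stage has already produced data/spectrogram/baker_zh_cuts_train.jsonl.gz and relies only on the defaults registered by add_arguments():

#!/usr/bin/env python3
# Sketch: build the BakerZh train dataloader the same way the deleted
# train.py does. Assumes the precomputed spectrogram manifests exist.
import argparse

from tts_datamodule import BakerZhSpeechTtsDataModule

parser = argparse.ArgumentParser()
BakerZhSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args([])  # keep defaults, e.g. --manifest-dir data/spectrogram

baker_zh = BakerZhSpeechTtsDataModule(args)
# Same duration filter as remove_short_and_long_utt() in train.py:
train_cuts = baker_zh.train_cuts().filter(lambda c: 1.0 <= c.duration <= 20.0)
train_dl = baker_zh.train_dataloaders(train_cuts)

batch = next(iter(train_dl))
# With the default PrecomputedFeatures input strategy, a batch holds
# "audio" (B, num_samples), "features" (B, T, 513), their lengths, and
# "tokens" -- a list of token sequences consumed by prepare_input().
print(batch["features"].shape, len(batch["tokens"]))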