remove baker-zh

Fangjun Kuang 2024-04-06 21:51:09 +08:00
parent f9bd5ced9d
commit 35578f0593
34 changed files with 0 additions and 3178 deletions

View File

@@ -1,7 +0,0 @@
# Introduction
[./symbols.py](./symbols.py) is copied from
https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py
[./pypinyin-local.dict](./pypinyin-local.dict) is copied from
https://github.com/UEhQZXI/vits_chinese/blob/master/misc/pypinyin-local.dict

View File

@@ -1,106 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes spectrogram features of the baker_zh dataset.
It looks for manifests in the directory data/manifests.
The generated spectrogram features are saved in data/spectrogram.
"""
import logging
import os
from pathlib import Path
import torch
from lhotse import (
CutSet,
LilcomChunkyWriter,
Spectrogram,
SpectrogramConfig,
load_manifest,
)
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_spectrogram_baker_zh():
src_dir = Path("data/manifests")
output_dir = Path("data/spectrogram")
num_jobs = min(4, os.cpu_count())
sampling_rate = 48000
    frame_length = 1024 / sampling_rate  # (in seconds)
    frame_shift = 256 / sampling_rate  # (in seconds)
use_fft_mag = True
prefix = "baker_zh"
suffix = "jsonl.gz"
partition = "all"
recordings = load_manifest(
src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
)
supervisions = load_manifest(
src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
)
config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=frame_length,
frame_shift=frame_shift,
use_fft_mag=use_fft_mag,
)
extractor = Spectrogram(config)
with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{cuts_filename} already exists - skipping.")
return
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=recordings, supervisions=supervisions
)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_spectrogram_baker_zh()
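
For reference, the configuration above works out as follows: at a 48 kHz sampling rate, frame_length = 1024 / 48000 ≈ 21.3 ms and frame_shift = 256 / 48000 ≈ 5.3 ms, and the 1024-point FFT yields 1024 // 2 + 1 = 513 bins per frame, which matches the feature_dim of 513 used by vits/train.py further below.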

View File

@@ -1,421 +0,0 @@
# This dict is copied from
# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py
pinyin_dict = {
"a": ("^", "a"),
"ai": ("^", "ai"),
"an": ("^", "an"),
"ang": ("^", "ang"),
"ao": ("^", "ao"),
"ba": ("b", "a"),
"bai": ("b", "ai"),
"ban": ("b", "an"),
"bang": ("b", "ang"),
"bao": ("b", "ao"),
"be": ("b", "e"),
"bei": ("b", "ei"),
"ben": ("b", "en"),
"beng": ("b", "eng"),
"bi": ("b", "i"),
"bian": ("b", "ian"),
"biao": ("b", "iao"),
"bie": ("b", "ie"),
"bin": ("b", "in"),
"bing": ("b", "ing"),
"bo": ("b", "o"),
"bu": ("b", "u"),
"ca": ("c", "a"),
"cai": ("c", "ai"),
"can": ("c", "an"),
"cang": ("c", "ang"),
"cao": ("c", "ao"),
"ce": ("c", "e"),
"cen": ("c", "en"),
"ceng": ("c", "eng"),
"cha": ("ch", "a"),
"chai": ("ch", "ai"),
"chan": ("ch", "an"),
"chang": ("ch", "ang"),
"chao": ("ch", "ao"),
"che": ("ch", "e"),
"chen": ("ch", "en"),
"cheng": ("ch", "eng"),
"chi": ("ch", "iii"),
"chong": ("ch", "ong"),
"chou": ("ch", "ou"),
"chu": ("ch", "u"),
"chua": ("ch", "ua"),
"chuai": ("ch", "uai"),
"chuan": ("ch", "uan"),
"chuang": ("ch", "uang"),
"chui": ("ch", "uei"),
"chun": ("ch", "uen"),
"chuo": ("ch", "uo"),
"ci": ("c", "ii"),
"cong": ("c", "ong"),
"cou": ("c", "ou"),
"cu": ("c", "u"),
"cuan": ("c", "uan"),
"cui": ("c", "uei"),
"cun": ("c", "uen"),
"cuo": ("c", "uo"),
"da": ("d", "a"),
"dai": ("d", "ai"),
"dan": ("d", "an"),
"dang": ("d", "ang"),
"dao": ("d", "ao"),
"de": ("d", "e"),
"dei": ("d", "ei"),
"den": ("d", "en"),
"deng": ("d", "eng"),
"di": ("d", "i"),
"dia": ("d", "ia"),
"dian": ("d", "ian"),
"diao": ("d", "iao"),
"die": ("d", "ie"),
"ding": ("d", "ing"),
"diu": ("d", "iou"),
"dong": ("d", "ong"),
"dou": ("d", "ou"),
"du": ("d", "u"),
"duan": ("d", "uan"),
"dui": ("d", "uei"),
"dun": ("d", "uen"),
"duo": ("d", "uo"),
"e": ("^", "e"),
"ei": ("^", "ei"),
"en": ("^", "en"),
"ng": ("^", "en"),
"eng": ("^", "eng"),
"er": ("^", "er"),
"fa": ("f", "a"),
"fan": ("f", "an"),
"fang": ("f", "ang"),
"fei": ("f", "ei"),
"fen": ("f", "en"),
"feng": ("f", "eng"),
"fo": ("f", "o"),
"fou": ("f", "ou"),
"fu": ("f", "u"),
"ga": ("g", "a"),
"gai": ("g", "ai"),
"gan": ("g", "an"),
"gang": ("g", "ang"),
"gao": ("g", "ao"),
"ge": ("g", "e"),
"gei": ("g", "ei"),
"gen": ("g", "en"),
"geng": ("g", "eng"),
"gong": ("g", "ong"),
"gou": ("g", "ou"),
"gu": ("g", "u"),
"gua": ("g", "ua"),
"guai": ("g", "uai"),
"guan": ("g", "uan"),
"guang": ("g", "uang"),
"gui": ("g", "uei"),
"gun": ("g", "uen"),
"guo": ("g", "uo"),
"ha": ("h", "a"),
"hai": ("h", "ai"),
"han": ("h", "an"),
"hang": ("h", "ang"),
"hao": ("h", "ao"),
"he": ("h", "e"),
"hei": ("h", "ei"),
"hen": ("h", "en"),
"heng": ("h", "eng"),
"hong": ("h", "ong"),
"hou": ("h", "ou"),
"hu": ("h", "u"),
"hua": ("h", "ua"),
"huai": ("h", "uai"),
"huan": ("h", "uan"),
"huang": ("h", "uang"),
"hui": ("h", "uei"),
"hun": ("h", "uen"),
"huo": ("h", "uo"),
"ji": ("j", "i"),
"jia": ("j", "ia"),
"jian": ("j", "ian"),
"jiang": ("j", "iang"),
"jiao": ("j", "iao"),
"jie": ("j", "ie"),
"jin": ("j", "in"),
"jing": ("j", "ing"),
"jiong": ("j", "iong"),
"jiu": ("j", "iou"),
"ju": ("j", "v"),
"juan": ("j", "van"),
"jue": ("j", "ve"),
"jun": ("j", "vn"),
"ka": ("k", "a"),
"kai": ("k", "ai"),
"kan": ("k", "an"),
"kang": ("k", "ang"),
"kao": ("k", "ao"),
"ke": ("k", "e"),
"kei": ("k", "ei"),
"ken": ("k", "en"),
"keng": ("k", "eng"),
"kong": ("k", "ong"),
"kou": ("k", "ou"),
"ku": ("k", "u"),
"kua": ("k", "ua"),
"kuai": ("k", "uai"),
"kuan": ("k", "uan"),
"kuang": ("k", "uang"),
"kui": ("k", "uei"),
"kun": ("k", "uen"),
"kuo": ("k", "uo"),
"la": ("l", "a"),
"lai": ("l", "ai"),
"lan": ("l", "an"),
"lang": ("l", "ang"),
"lao": ("l", "ao"),
"le": ("l", "e"),
"lei": ("l", "ei"),
"leng": ("l", "eng"),
"li": ("l", "i"),
"lia": ("l", "ia"),
"lian": ("l", "ian"),
"liang": ("l", "iang"),
"liao": ("l", "iao"),
"lie": ("l", "ie"),
"lin": ("l", "in"),
"ling": ("l", "ing"),
"liu": ("l", "iou"),
"lo": ("l", "o"),
"long": ("l", "ong"),
"lou": ("l", "ou"),
"lu": ("l", "u"),
"lv": ("l", "v"),
"luan": ("l", "uan"),
"lve": ("l", "ve"),
"lue": ("l", "ve"),
"lun": ("l", "uen"),
"luo": ("l", "uo"),
"ma": ("m", "a"),
"mai": ("m", "ai"),
"man": ("m", "an"),
"mang": ("m", "ang"),
"mao": ("m", "ao"),
"me": ("m", "e"),
"mei": ("m", "ei"),
"men": ("m", "en"),
"meng": ("m", "eng"),
"mi": ("m", "i"),
"mian": ("m", "ian"),
"miao": ("m", "iao"),
"mie": ("m", "ie"),
"min": ("m", "in"),
"ming": ("m", "ing"),
"miu": ("m", "iou"),
"mo": ("m", "o"),
"mou": ("m", "ou"),
"mu": ("m", "u"),
"na": ("n", "a"),
"nai": ("n", "ai"),
"nan": ("n", "an"),
"nang": ("n", "ang"),
"nao": ("n", "ao"),
"ne": ("n", "e"),
"nei": ("n", "ei"),
"nen": ("n", "en"),
"neng": ("n", "eng"),
"ni": ("n", "i"),
"nia": ("n", "ia"),
"nian": ("n", "ian"),
"niang": ("n", "iang"),
"niao": ("n", "iao"),
"nie": ("n", "ie"),
"nin": ("n", "in"),
"ning": ("n", "ing"),
"niu": ("n", "iou"),
"nong": ("n", "ong"),
"nou": ("n", "ou"),
"nu": ("n", "u"),
"nv": ("n", "v"),
"nuan": ("n", "uan"),
"nve": ("n", "ve"),
"nue": ("n", "ve"),
"nuo": ("n", "uo"),
"o": ("^", "o"),
"ou": ("^", "ou"),
"pa": ("p", "a"),
"pai": ("p", "ai"),
"pan": ("p", "an"),
"pang": ("p", "ang"),
"pao": ("p", "ao"),
"pe": ("p", "e"),
"pei": ("p", "ei"),
"pen": ("p", "en"),
"peng": ("p", "eng"),
"pi": ("p", "i"),
"pian": ("p", "ian"),
"piao": ("p", "iao"),
"pie": ("p", "ie"),
"pin": ("p", "in"),
"ping": ("p", "ing"),
"po": ("p", "o"),
"pou": ("p", "ou"),
"pu": ("p", "u"),
"qi": ("q", "i"),
"qia": ("q", "ia"),
"qian": ("q", "ian"),
"qiang": ("q", "iang"),
"qiao": ("q", "iao"),
"qie": ("q", "ie"),
"qin": ("q", "in"),
"qing": ("q", "ing"),
"qiong": ("q", "iong"),
"qiu": ("q", "iou"),
"qu": ("q", "v"),
"quan": ("q", "van"),
"que": ("q", "ve"),
"qun": ("q", "vn"),
"ran": ("r", "an"),
"rang": ("r", "ang"),
"rao": ("r", "ao"),
"re": ("r", "e"),
"ren": ("r", "en"),
"reng": ("r", "eng"),
"ri": ("r", "iii"),
"rong": ("r", "ong"),
"rou": ("r", "ou"),
"ru": ("r", "u"),
"rua": ("r", "ua"),
"ruan": ("r", "uan"),
"rui": ("r", "uei"),
"run": ("r", "uen"),
"ruo": ("r", "uo"),
"sa": ("s", "a"),
"sai": ("s", "ai"),
"san": ("s", "an"),
"sang": ("s", "ang"),
"sao": ("s", "ao"),
"se": ("s", "e"),
"sen": ("s", "en"),
"seng": ("s", "eng"),
"sha": ("sh", "a"),
"shai": ("sh", "ai"),
"shan": ("sh", "an"),
"shang": ("sh", "ang"),
"shao": ("sh", "ao"),
"she": ("sh", "e"),
"shei": ("sh", "ei"),
"shen": ("sh", "en"),
"sheng": ("sh", "eng"),
"shi": ("sh", "iii"),
"shou": ("sh", "ou"),
"shu": ("sh", "u"),
"shua": ("sh", "ua"),
"shuai": ("sh", "uai"),
"shuan": ("sh", "uan"),
"shuang": ("sh", "uang"),
"shui": ("sh", "uei"),
"shun": ("sh", "uen"),
"shuo": ("sh", "uo"),
"si": ("s", "ii"),
"song": ("s", "ong"),
"sou": ("s", "ou"),
"su": ("s", "u"),
"suan": ("s", "uan"),
"sui": ("s", "uei"),
"sun": ("s", "uen"),
"suo": ("s", "uo"),
"ta": ("t", "a"),
"tai": ("t", "ai"),
"tan": ("t", "an"),
"tang": ("t", "ang"),
"tao": ("t", "ao"),
"te": ("t", "e"),
"tei": ("t", "ei"),
"teng": ("t", "eng"),
"ti": ("t", "i"),
"tian": ("t", "ian"),
"tiao": ("t", "iao"),
"tie": ("t", "ie"),
"ting": ("t", "ing"),
"tong": ("t", "ong"),
"tou": ("t", "ou"),
"tu": ("t", "u"),
"tuan": ("t", "uan"),
"tui": ("t", "uei"),
"tun": ("t", "uen"),
"tuo": ("t", "uo"),
"wa": ("^", "ua"),
"wai": ("^", "uai"),
"wan": ("^", "uan"),
"wang": ("^", "uang"),
"wei": ("^", "uei"),
"wen": ("^", "uen"),
"weng": ("^", "ueng"),
"wo": ("^", "uo"),
"wu": ("^", "u"),
"xi": ("x", "i"),
"xia": ("x", "ia"),
"xian": ("x", "ian"),
"xiang": ("x", "iang"),
"xiao": ("x", "iao"),
"xie": ("x", "ie"),
"xin": ("x", "in"),
"xing": ("x", "ing"),
"xiong": ("x", "iong"),
"xiu": ("x", "iou"),
"xu": ("x", "v"),
"xuan": ("x", "van"),
"xue": ("x", "ve"),
"xun": ("x", "vn"),
"ya": ("^", "ia"),
"yan": ("^", "ian"),
"yang": ("^", "iang"),
"yao": ("^", "iao"),
"ye": ("^", "ie"),
"yi": ("^", "i"),
"yin": ("^", "in"),
"ying": ("^", "ing"),
"yo": ("^", "iou"),
"yong": ("^", "iong"),
"you": ("^", "iou"),
"yu": ("^", "v"),
"yuan": ("^", "van"),
"yue": ("^", "ve"),
"yun": ("^", "vn"),
"za": ("z", "a"),
"zai": ("z", "ai"),
"zan": ("z", "an"),
"zang": ("z", "ang"),
"zao": ("z", "ao"),
"ze": ("z", "e"),
"zei": ("z", "ei"),
"zen": ("z", "en"),
"zeng": ("z", "eng"),
"zha": ("zh", "a"),
"zhai": ("zh", "ai"),
"zhan": ("zh", "an"),
"zhang": ("zh", "ang"),
"zhao": ("zh", "ao"),
"zhe": ("zh", "e"),
"zhei": ("zh", "ei"),
"zhen": ("zh", "en"),
"zheng": ("zh", "eng"),
"zhi": ("zh", "iii"),
"zhong": ("zh", "ong"),
"zhou": ("zh", "ou"),
"zhu": ("zh", "u"),
"zhua": ("zh", "ua"),
"zhuai": ("zh", "uai"),
"zhuan": ("zh", "uan"),
"zhuang": ("zh", "uang"),
"zhui": ("zh", "uei"),
"zhun": ("zh", "uen"),
"zhuo": ("zh", "uo"),
"zi": ("z", "ii"),
"zong": ("z", "ong"),
"zou": ("z", "ou"),
"zu": ("z", "u"),
"zuan": ("z", "uan"),
"zui": ("z", "uei"),
"zun": ("z", "uen"),
"zuo": ("z", "uo"),
}
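
To show how this table is consumed (compare _get_phoneme4pinyin in the tokenizer further below), here is a minimal sketch that splits a tonal pinyin syllable into an initial plus a tonal final; the helper name split_syllable is ours, not part of the recipe:

from pinyin_dict import pinyin_dict

def split_syllable(syllable: str):
    # "zhong1" -> tone digit "1", toneless base "zhong"
    tone = syllable[-1]
    initial, final = pinyin_dict[syllable[:-1]]
    return initial, final + tone

print(split_syllable("zhong1"))  # ('zh', 'ong1')
print(split_syllable("yi2"))  # ('^', 'i2')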

View File

@@ -1,53 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file generates the file that maps tokens to IDs.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict
from symbols import symbols
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--tokens",
type=Path,
default=Path("data/tokens.txt"),
help="Path to the dict that maps the text tokens to IDs",
)
return parser.parse_args()
def main():
args = get_args()
tokens = Path(args.tokens)
with open(tokens, "w", encoding="utf-8") as f:
for token_id, token in enumerate(symbols):
f.write(f"{token} {token_id}\n")
if __name__ == "__main__":
main()
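
Given the symbol order defined in symbols.py (the seven pauses first, then the initials, then the tonal finals), the generated data/tokens.txt should begin as follows; a sketch of the expected format:

sil 0
eos 1
sp 2
#0 3
#1 4
#2 5
#3 6
^ 7
b 8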

View File

@@ -1,59 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file reads the texts in the given manifest and saves new cuts with tokens.
"""
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest
from tokenizer import Tokenizer
def prepare_tokens_baker_zh():
output_dir = Path("data/spectrogram")
prefix = "baker_zh"
suffix = "jsonl.gz"
partition = "all"
cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
tokenizer = Tokenizer()
new_cuts = []
for cut in cut_set:
# Each cut only contains one supervision
assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
text = cut.supervisions[0].normalized_text
cut.tokens = tokenizer.text_to_tokens(text)
new_cuts.append(cut)
new_cut_set = CutSet.from_cuts(new_cuts)
new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
prepare_tokens_baker_zh()
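
After this step every cut in baker_zh_cuts_with_tokens_all.jsonl.gz carries a custom tokens field. A quick inspection sketch, assuming the manifest has been generated:

from lhotse import load_manifest

cuts = load_manifest("data/spectrogram/baker_zh_cuts_with_tokens_all.jsonl.gz")
cut = next(iter(cuts))
print(cut.supervisions[0].normalized_text)
print(cut.tokens)  # phoneme list, e.g. ['sil', ..., 'eos']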

View File

@@ -1,328 +0,0 @@
姐姐 jie3 jie
宝宝 bao3 bao
哥哥 ge1 ge
妹妹 mei4 mei
弟弟 di4 di
妈妈 ma1 ma
开心哦 kai1 xin1 o
爸爸 ba4 ba
秘密哟 mi4 mi4 yo
哦 o
一年 yi4 nian2
一夜 yi2 ye4
一切 yi2 qie4
一座 yi2 zuo4
一下 yi2 xia4
上一山 shang4 yi2 shan1
下一山 xia4 yi2 shan1
休息 xiu1 xi2
东西 dong1 xi
上一届 shang4 yi2 jie4
便宜 pian2 yi4
加长 jia1 chang2
单田芳 shan4 tian2 fang1
帧 zhen1
长时间 chang2 shi2 jian1
长时 chang2 shi2
识别 shi2 bie2
生命中 sheng1 ming4 zhong1
踏实 ta1 shi
嗯 en4
溜达 liu1 da
少儿 shao4 er2
爷爷 ye2 ye
不是 bu2 shi4
一圈 yi1 quan1
厜读一声 zui1 du2 yi4 sheng1
一种 yi4 zhong3
一簇簇 yi2 cu4 cu4
一个 yi2 ge4
一样 yi2 yang4
一跩一跩 yi4 zhuai3 yi4 zhuai3
一会儿 yi2 hui4 er
一幢 yi2 zhuang4
挨了 ai2 le
熬菜 ao1 cai4
扒鸡 pa2 ji1
背枪 bei1 qiang1
绷瓷儿 beng4 ci2 er2
绷劲儿 beng3 jin4 er
绷着脸 beng3 zhe lian3
藏医 zang4 yi1
噌吰 cheng1 hong2
差点儿 cha4 dian3 er
差失 cha1 shi1
差误 cha1 wu4
孱头 can4 tou
乘间 cheng2 jian4
锄镰棘矜 chu2 lian2 ji2 qin2
川藏 chuan1 zang4
穿著 chuan1 zhuo2
答讪 da1 shan4
答言 da1 yan2
大伯子 da4 bai3 zi
大夫 dai4 fu
弹冠 tan2 guan1
当间 dang1 jian4
当然咯 dang1 ran2 lo
点种 dian3 zhong3
垛好 duo4 hao3
发疟子 fa1 yao4 zi
饭熟了 fan4 shou2 le
附著 fu4 zhuo2
复沓 fu4 ta4
供稿 gong1 gao3
供养 gong1 yang3
骨朵 gu1 duo
骨碌 gu1 lu
果脯 guo3 fu3
哈什玛 ha4 shi2 ma3
海蜇 hai3 zhe2
呵欠 he1 qian
河水汤汤 he2 shui3 shang1 shang1
鹄立 hu2 li4
鹄望 hu2 wang4
混人 hun2 ren2
混水 hun2 shui3
鸡血 ji1 xie3
缉鞋口 qi1 xie2 kou3
亟来闻讯 qi4 lai2 wen2 xun4
计量 ji4 liang2
济水 ji3 shui3
间杂 jian4 za2
脚跐两只船 jiao3 ci3 liang3 zhi1 chuan2
脚儿 jue2 er2
口角 kou3 jiao3
勒石 le4 shi2
累进 lei3 jin4
累累如丧家之犬 lei2 lei2 ru2 sang4 jia1 zhi1 quan3
累年 lei3 nian2
脸涨通红 lian3 zhang4 tong1 hong2
踉锵 liang4 qiang1
燎眉毛 liao3 mei2 mao2
燎头发 liao3 tou2 fa4
溜达 liu1 da
溜缝儿 liu4 feng4 er
馏口饭 liu4 kou3 fan4
遛马 liu4 ma3
遛鸟 liu4 niao3
遛弯儿 liu4 wan1 er
楼枪机 lou1 qiang1 ji1
搂钱 lou1 qian2
鹿脯 lu4 fu3
露头 lou4 tou2
落魄 luo4 po4
捋胡子 lv3 hu2 zi
绿地 lv4 di4
麦垛 mai4 duo4
没劲儿 mei2 jin4 er
闷棍 men4 gun4
闷葫芦 men4 hu2 lu
闷头干 men1 tou2 gan4
蒙古 meng3 gu3
靡日不思 mi3 ri4 bu4 si1
缪姓 miao4 xing4
抹墙 mo4 qiang2
抹下脸 ma1 xia4 lian3
泥子 ni4 zi
拗不过 niu4 bu guo4
排车 pai3 che1
盘诘 pan2 jie2
膀肿 pang1 zhong3
炮干 bao1 gan1
炮格 pao2 ge2
碰钉子 peng4 ding1 zi
缥色 piao3 se4
瀑河 bao4 he2
蹊径 xi1 jing4
前后相属 qian2 hou4 xiang1 zhu3
翘尾巴 qiao4 wei3 ba
趄坡儿 qie4 po1 er
秦桧 qin2 hui4
圈马 juan1 ma3
雀盲眼 qiao3 mang2 yan3
雀子 qiao1 zi
三年五载 san1 nian2 wu3 zai3
加载 jia1 zai3
山大王 shan1 dai4 wang
苫屋草 shan4 wu1 cao3
数数 shu3 shu4
说客 shui4 ke4
思量 si1 liang2
伺侯 ci4 hou
踏实 ta1 shi
提溜 di1 liu
调拨 diao4 bo1
帖子 tie3 zi
铜钿 tong2 tian2
头昏脑涨 tou2 hun1 nao3 zhang4
褪色 tui4 se4
褪着手 tun4 zhe shou3
圩子 wei2 zi
尾巴 wei3 ba
系好船只 xi4 hao3 chuan2 zhi1
系好马匹 xi4 hao3 ma3 pi3
杏脯 xing4 fu3
姓单 xing4 shan4
姓葛 xing4 ge3
姓哈 xing4 ha3
姓解 xing4 xie4
姓秘 xing4 bi4
姓宁 xing4 ning4
旋风 xuan4 feng1
旋根车轴 xuan4 gen1 che1 zhou2
荨麻 qian2 ma2
一幢楼房 yi1 zhuang4 lou2 fang2
遗之千金 wei4 zhi1 qian1 jin1
殷殷 yin3 yin3
应招 ying4 zhao1
用称约 yong4 cheng4 yao1
约斤肉 yao1 jin1 rou4
晕机 yun4 ji1
熨贴 yu4 tie1
咋办 za3 ban4
咋呼 zha1 hu
仔兽 zi3 shou4
扎彩 za1 cai3
扎实 zha1 shi
扎腰带 za1 yao1 dai4
轧朋友 ga2 peng2 you3
爪子 zhua3 zi
折腾 zhe1 teng
着实 zhuo2 shi2
着我旧时裳 zhuo2 wo3 jiu4 shi2 chang2
枝蔓 zhi1 man4
中鹄 zhong1 hu2
中选 zhong4 xuan3
猪圈 zhu1 juan4
拽住不放 zhuai4 zhu4 bu4 fang4
转悠 zhuan4 you
庄稼熟了 zhuang1 jia shou2 le
酌量 zhuo2 liang2
罪行累累 zui4 xing2 lei3 lei3
一手 yi4 shou3
一去不复返 yi2 qu4 bu2 fu4 fan3
一颗 yi4 ke1
一件 yi2 jian4
一斤 yi4 jin1
一点 yi4 dian3
一朵 yi4 duo3
一声 yi4 sheng1
一身 yi4 shen1
不要 bu2 yao4
一人 yi4 ren2
一个 yi2 ge4
一把 yi4 ba3
一门 yi4 men2
一門 yi4 men2
一艘 yi4 sou1
一片 yi2 pian4
一篇 yi2 pian1
一份 yi2 fen4
好嗲 hao3 dia3
随地 sui2 di4
扁担长 bian3 dan4 chang3
一堆 yi4 dui1
不义 bu2 yi4
放一放 fang4 yi2 fang4
一米 yi4 mi3
一顿 yi2 dun4
一层楼 yi4 ceng2 lou2
一条 yi4 tiao2
一件 yi2 jian4
一棵 yi4 ke1
一小股 yi4 xiao3 gu3
一拐一拐 yi4 guai3 yi4 guai3
一根 yi4 gen1
沆瀣一气 hang4 xie4 yi2 qi4
一丝 yi4 si1
一毫 yi4 hao2
一樣 yi2 yang4
处处 chu4 chu4
一餐 yi4 can
永不 yong3 bu2
一看 yi2 kan4
一架 yi2 jia4
送还 song4 huan2
一见 yi2 jian4
一座 yi2 zuo4
一块 yi2 kuai4
一天 yi4 tian1
一只 yi4 zhi1
一支 yi4 zhi1
一字 yi2 zi4
一句 yi2 ju4
一张 yi4 zhang1
一條 yi4 tiao2
一场 yi4 chang3
一粒 yi2 li4
小俩口 xiao3 liang3 kou3
一首 yi4 shou3
一对 yi2 dui4
一手 yi4 shou3
又一村 you4 yi4 cun1
一概而论 yi2 gai4 er2 lun4
一峰峰 yi4 feng1 feng1
不但 bu2 dan4
一笑 yi2 xiao4
挠痒痒 nao2 yang3 yang
不对 bu2 dui4
拧开 ning3 kai1
爱不释手 ai4 bu2 shi4 shou3
一念 yi2 nian4
夺得 duo2 de2
一袭 yi4 xi2
一定 yi2 ding4
不慎 bu2 shen4
剽窃 piao2 qie4
一时 yi4 shi2
撇开 pie3 kai1
一祭 yi2 ji4
发卡 fa4 qia3
少不了 shao3 bu4 liao3
千虑一失 qian1 lv4 yi4 shi1
呛得 qiang4 de2
切菜 qie1 cai4
茄盒 qie2 he2
不去 bu2 qu4
一大圈 yi2 da4 quan1
不再 bu2 zai4
一群 yi4 qun2
不必 bu2 bi4
一些 yi4 xie1
一路 yi2 lu4
一股 yi4 gu3
一到 yi2 dao4
一拨 yi4 bo1
一排 yi4 pai2
一空 yi4 kong1
吮吸着 shun3 xi1 zhe
不适合 bu2 shi4 he2
一串串 yi2 chuan4 chuan4
一提起 yi4 ti2 qi3
一尘不染 yi4 chen2 bu4 ran3
一生 yi4 sheng1
一派 yi2 pai4
不断 bu2 duan4
一次 yi2 ci4
不进步 bu2 jin4 bu4
娃娃 wa2 wa
万户侯 wan4 hu4 hou2
一方 yi4 fang1
一番话 yi4 fan1 hua4
一遍 yi2 bian4
不计较 bu2 ji4 jiao4
诇 xiong4
一边 yi4 bian1
一束 yi2 shu4
一听到 yi4 ting1 dao4
炸鸡 zha2 ji1
乍暧还寒 zha4 ai4 huan2 han2
我说诶 wo3 shuo1 ei1
棒诶 bang4 ei1
寒碜 han2 chen4
应采儿 ying4 cai3 er2
晕车 yun1 che1
必应 bi4 ying4
应援 ying4 yuan2
应力 ying4 li4
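
Each line of this dict holds a word followed by its per-character pinyin; syllables written without a tone digit (e.g. "jie3 jie") are neutral-tone, which NeutralToneWith5Mixin in the tokenizer below renders as tone 5. The tokenizer's _load_pinyin_dict turns each line into the nested-list format that pypinyin's load_phrases_dict expects; a one-entry sketch:

line = "姐姐 jie3 jie"
cuts = line.strip().split()
entry = {cuts[0]: [[p] for p in cuts[1:]]}
# {'姐姐': [['jie3'], ['jie']]}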

View File

@@ -1,73 +0,0 @@
# This file is copied from
# https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py
_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"]
_initials = [
"^",
"b",
"c",
"ch",
"d",
"f",
"g",
"h",
"j",
"k",
"l",
"m",
"n",
"p",
"q",
"r",
"s",
"sh",
"t",
"x",
"z",
"zh",
]
_tones = ["1", "2", "3", "4", "5"]
_finals = [
"a",
"ai",
"an",
"ang",
"ao",
"e",
"ei",
"en",
"eng",
"er",
"i",
"ia",
"ian",
"iang",
"iao",
"ie",
"ii",
"iii",
"in",
"ing",
"iong",
"iou",
"o",
"ong",
"ou",
"u",
"ua",
"uai",
"uan",
"uang",
"uei",
"uen",
"ueng",
"uo",
"v",
"van",
"ve",
"vn",
]
symbols = _pause + _initials + [i + j for i in _finals for j in _tones]
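
As a sanity check, the size of this inventory can be computed directly from the lists above; a small sketch, not part of the recipe:

from symbols import _finals, _initials, _pause, _tones, symbols

# 7 pauses + 22 toneless initials + 38 finals x 5 tones = 219 symbols
assert len(symbols) == len(_pause) + len(_initials) + len(_finals) * len(_tones) == 219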

View File

@@ -1,137 +0,0 @@
# This file is modified from
# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py
import logging
from pathlib import Path
from typing import Dict, List
# Note pinyin_dict is from ./pinyin_dict.py
from pinyin_dict import pinyin_dict
from pypinyin import Style
from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
from pypinyin.converter import DefaultConverter
from pypinyin.core import Pinyin, load_phrases_dict
class _MyConverter(NeutralToneWith5Mixin, DefaultConverter):
pass
class Tokenizer:
def __init__(self, tokens: str = ""):
self._load_pinyin_dict()
self._pinyin_parser = Pinyin(_MyConverter())
if tokens != "":
self._load_tokens(tokens)
def texts_to_token_ids(self, texts: List[str], **kwargs) -> List[List[int]]:
"""
Args:
texts:
A list of sentences.
kwargs:
Not used. It is for compatibility with other TTS recipes in icefall.
"""
tokens = []
for text in texts:
tokens.append(self.text_to_tokens(text))
return self.tokens_to_token_ids(tokens)
def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]:
ans = []
for token_list in tokens:
token_ids = []
for t in token_list:
if t not in self.token2id:
logging.warning(f"Skip OOV {t}")
continue
token_ids.append(self.token2id[t])
ans.append(token_ids)
return ans
def text_to_tokens(self, text: str) -> List[str]:
# Convert "" to ["sp", "sil"]
# Convert "。" to ["sil"]
# append ["eos"] at the end of a sentence
phonemes = ["sil"]
pinyins = self._pinyin_parser.pinyin(
text,
style=Style.TONE3,
errors=lambda x: [[w] for w in x],
)
new_pinyin = []
for p in pinyins:
p = p[0]
if p == "":
new_pinyin.extend(["sp", "sil"])
elif p == "":
new_pinyin.append("sil")
else:
new_pinyin.append(p)
sub_phonemes = self._get_phoneme4pinyin(new_pinyin)
sub_phonemes.append("eos")
phonemes.extend(sub_phonemes)
return phonemes
def _get_phoneme4pinyin(self, pinyins):
result = []
for pinyin in pinyins:
if pinyin in ("sil", "sp"):
result.append(pinyin)
elif pinyin[:-1] in pinyin_dict:
tone = pinyin[-1]
a = pinyin[:-1]
a1, a2 = pinyin_dict[a]
# every word is appended with a #0
result += [a1, a2 + tone, "#0"]
return result
def _load_pinyin_dict(self):
this_dir = Path(__file__).parent.resolve()
my_dict = {}
with open(f"{this_dir}/pypinyin-local.dict", "r", encoding="utf-8") as f:
content = f.readlines()
for line in content:
cuts = line.strip().split()
hanzi = cuts[0]
pinyin = cuts[1:]
my_dict[hanzi] = [[p] for p in pinyin]
load_phrases_dict(my_dict)
def _load_tokens(self, filename):
token2id: Dict[str, int] = {}
with open(filename, "r", encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split()
if len(info) == 1:
# case of space
token = " "
idx = int(info[0])
else:
token, idx = info[0], int(info[1])
assert token not in token2id, token
token2id[token] = idx
self.token2id = token2id
self.vocab_size = len(self.token2id)
self.pad_id = self.token2id["#0"]
def main():
tokenizer = Tokenizer()
    tokenizer.text_to_tokens("你好,好的。")
if __name__ == "__main__":
main()
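
A usage sketch, assuming data/tokens.txt has already been generated by ./local/prepare_token_file.py:

from tokenizer import Tokenizer

tokenizer = Tokenizer("data/tokens.txt")
tokens = tokenizer.text_to_tokens("你好")
# expected: ['sil', 'n', 'i3', '#0', 'h', 'ao3', '#0', 'eos']
print(tokens)
print(tokenizer.tokens_to_token_ids([tokens]))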

View File

@@ -1 +0,0 @@
../../../ljspeech/TTS/local/validate_manifest.py

View File

@@ -1,124 +0,0 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=-1
stop_stage=100
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: build monotonic_align lib"
if [ ! -d vits/monotonic_align/build ]; then
cd vits/monotonic_align
python3 setup.py build_ext --inplace
cd ../../
else
log "monotonic_align lib already built"
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Download data"
# The directory $dl_dir/BZNSYP will contain 3 sub directories:
# - PhoneLabeling
# - ProsodyLabeling
# - Wave
# If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink
#
# ln -sfv /path/to/BZNSYP $dl_dir/
# touch $dl_dir/BZNSYP/.completed
#
if [ ! -d $dl_dir/BZNSYP ]; then
lhotse download baker-zh $dl_dir
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare baker-zh manifest"
# We assume that you have downloaded the baker corpus
# to $dl_dir/BZNSYP
mkdir -p data/manifests
if [ ! -e data/manifests/.baker.done ]; then
lhotse prepare baker-zh $dl_dir/BZNSYP data/manifests
touch data/manifests/.baker.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Compute spectrogram for baker (may take 3 minutes)"
mkdir -p data/spectrogram
if [ ! -e data/spectrogram/.baker.done ]; then
./local/compute_spectrogram_baker.py
touch data/spectrogram/.baker.done
fi
if [ ! -e data/spectrogram/.baker-validated.done ]; then
log "Validating data/spectrogram for baker"
python3 ./local/validate_manifest.py \
data/spectrogram/baker_zh_cuts_all.jsonl.gz
touch data/spectrogram/.baker-validated.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Prepare tokens for baker-zh (may take 20 seconds)"
if [ ! -e data/spectrogram/.baker_zh_with_token.done ]; then
./local/prepare_tokens_baker_zh.py
mv -v data/spectrogram/baker_zh_cuts_with_tokens_all.jsonl.gz \
data/spectrogram/baker_zh_cuts_all.jsonl.gz
touch data/spectrogram/.baker_zh_with_token.done
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Split the baker-zh cuts into train, valid and test sets (may take 25 seconds)"
if [ ! -e data/spectrogram/.baker_zh_split.done ]; then
lhotse subset --last 600 \
data/spectrogram/baker_zh_cuts_all.jsonl.gz \
data/spectrogram/baker_zh_cuts_validtest.jsonl.gz
lhotse subset --first 100 \
data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \
data/spectrogram/baker_zh_cuts_valid.jsonl.gz
lhotse subset --last 500 \
data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \
data/spectrogram/baker_zh_cuts_test.jsonl.gz
rm data/spectrogram/baker_zh_cuts_validtest.jsonl.gz
n=$(( $(gunzip -c data/spectrogram/baker_zh_cuts_all.jsonl.gz | wc -l) - 600 ))
lhotse subset --first $n \
data/spectrogram/baker_zh_cuts_all.jsonl.gz \
data/spectrogram/baker_zh_cuts_train.jsonl.gz
touch data/spectrogram/.baker_zh_split.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Generate token file"
if [ ! -e data/tokens.txt ]; then
./local/prepare_token_file.py --tokens data/tokens.txt
fi
fi
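
The stages above can be run selectively: shared/parse_options.sh (the standard kaldi-style option parser, symlinked below) maps command-line flags onto the stage and stop_stage variables at the top of the script, so e.g. ./prepare.sh --stage 3 --stop-stage 3 reruns only the spectrogram computation.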

View File

@@ -1 +0,0 @@
../../../icefall/shared

View File

@@ -1 +0,0 @@
../../../ljspeech/TTS/vits/duration_predictor.py

View File

@@ -1,414 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script exports a VITS model from PyTorch to ONNX.
Export the model to ONNX:
./vits/export-onnx.py \
--epoch 1000 \
--exp-dir vits/exp \
--tokens data/tokens.txt
It will generate one file inside vits/exp:
- vits-epoch-1000.onnx
See ./test_onnx.py for how to use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import torch
import torch.nn as nn
from tokenizer import Tokenizer
from train import get_model, get_params
from icefall.checkpoint import load_checkpoint
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=1000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
help="The experiment dir",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
parser.add_argument(
"--model-type",
type=str,
default="high",
choices=["low", "medium", "high"],
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
class OnnxModel(nn.Module):
"""A wrapper for VITS generator."""
def __init__(self, model: nn.Module):
"""
Args:
model:
A VITS generator.
"""
super().__init__()
self.model = model
def forward(
self,
tokens: torch.Tensor,
tokens_lens: torch.Tensor,
noise_scale: float = 0.667,
alpha: float = 1.0,
noise_scale_dur: float = 0.8,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of VITS.inference_batch
Args:
tokens:
Input text token indexes (1, T_text)
tokens_lens:
Number of tokens of shape (1,)
noise_scale (float):
Noise scale parameter for flow.
noise_scale_dur (float):
Noise scale parameter for duration predictor.
alpha (float):
Alpha parameter to control the speed of generated speech.
Returns:
Return a tuple containing:
            - audio, generated waveform tensor, (B, T_wav)
"""
audio, _, _ = self.model.generator.inference(
text=tokens,
text_lengths=tokens_lens,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,
)
return audio
def export_model_onnx(
model: nn.Module,
model_filename: str,
vocab_size: int,
opset_version: int = 11,
) -> None:
"""Export the given generator model to ONNX format.
    The exported model has the following inputs:
      - tokens, a tensor of shape (1, T_text); dtype is torch.int64
      - tokens_lens, a tensor of shape (1,); dtype is torch.int64
      - noise_scale, alpha and noise_scale_dur, scalar float32 tensors
    and it has one output:
      - audio, a tensor of shape (1, T'); dtype is torch.float32
Args:
model:
The VITS generator.
model_filename:
The filename to save the exported ONNX model.
vocab_size:
Number of tokens used in training.
opset_version:
The opset version to use.
"""
tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1], dtype=torch.float32)
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
alpha = torch.tensor([1], dtype=torch.float32)
torch.onnx.export(
model,
(tokens, tokens_lens, noise_scale, alpha, noise_scale_dur),
model_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"tokens",
"tokens_lens",
"noise_scale",
"alpha",
"noise_scale_dur",
],
output_names=["audio"],
dynamic_axes={
"tokens": {0: "N", 1: "T"},
"tokens_lens": {0: "N"},
"audio": {0: "N", 1: "T"},
},
)
if model.model.spks is None:
num_speakers = 1
else:
num_speakers = model.model.spks
meta_data = {
"model_type": "vits",
"version": "1",
"model_author": "k2-fsa",
"comment": "icefall", # must be icefall for models from icefall
"language": "Chinese",
"n_speakers": num_speakers,
"sample_rate": model.model.sampling_rate, # Must match the real sample rate
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=model_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
logging.info(params)
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to("cpu")
model.eval()
model = OnnxModel(model=model)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")
suffix = f"epoch-{params.epoch}"
opset_version = 13
logging.info("Exporting encoder")
model_filename = params.exp_dir / f"vits-{suffix}.onnx"
export_model_onnx(
model,
model_filename,
params.vocab_size,
opset_version=opset_version,
)
logging.info(f"Exported generator to {model_filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
"""
Supported languages.
LJSpeech is using "en-us" from the second column.
Pty Language Age/Gender VoiceName File Other Languages
5 af --/M Afrikaans gmw/af
5 am --/M Amharic sem/am
5 an --/M Aragonese roa/an
5 ar --/M Arabic sem/ar
5 as --/M Assamese inc/as
5 az --/M Azerbaijani trk/az
5 ba --/M Bashkir trk/ba
5 be --/M Belarusian zle/be
5 bg --/M Bulgarian zls/bg
5 bn --/M Bengali inc/bn
5 bpy --/M Bishnupriya_Manipuri inc/bpy
5 bs --/M Bosnian zls/bs
5 ca --/M Catalan roa/ca
5 chr-US-Qaaa-x-west --/M Cherokee_ iro/chr
5 cmn --/M Chinese_(Mandarin,_latin_as_English) sit/cmn (zh-cmn 5)(zh 5)
5 cmn-latn-pinyin --/M Chinese_(Mandarin,_latin_as_Pinyin) sit/cmn-Latn-pinyin (zh-cmn 5)(zh 5)
5 cs --/M Czech zlw/cs
5 cv --/M Chuvash trk/cv
5 cy --/M Welsh cel/cy
5 da --/M Danish gmq/da
5 de --/M German gmw/de
5 el --/M Greek grk/el
5 en-029 --/M English_(Caribbean) gmw/en-029 (en 10)
2 en-gb --/M English_(Great_Britain) gmw/en (en 2)
5 en-gb-scotland --/M English_(Scotland) gmw/en-GB-scotland (en 4)
5 en-gb-x-gbclan --/M English_(Lancaster) gmw/en-GB-x-gbclan (en-gb 3)(en 5)
5 en-gb-x-gbcwmd --/M English_(West_Midlands) gmw/en-GB-x-gbcwmd (en-gb 9)(en 9)
5 en-gb-x-rp --/M English_(Received_Pronunciation) gmw/en-GB-x-rp (en-gb 4)(en 5)
2 en-us --/M English_(America) gmw/en-US (en 3)
5 en-us-nyc --/M English_(America,_New_York_City) gmw/en-US-nyc
5 eo --/M Esperanto art/eo
5 es --/M Spanish_(Spain) roa/es
5 es-419 --/M Spanish_(Latin_America) roa/es-419 (es-mx 6)
5 et --/M Estonian urj/et
5 eu --/M Basque eu
5 fa --/M Persian ira/fa
5 fa-latn --/M Persian_(Pinglish) ira/fa-Latn
5 fi --/M Finnish urj/fi
5 fr-be --/M French_(Belgium) roa/fr-BE (fr 8)
5 fr-ch --/M French_(Switzerland) roa/fr-CH (fr 8)
5 fr-fr --/M French_(France) roa/fr (fr 5)
5 ga --/M Gaelic_(Irish) cel/ga
5 gd --/M Gaelic_(Scottish) cel/gd
5 gn --/M Guarani sai/gn
5 grc --/M Greek_(Ancient) grk/grc
5 gu --/M Gujarati inc/gu
5 hak --/M Hakka_Chinese sit/hak
5 haw --/M Hawaiian map/haw
5 he --/M Hebrew sem/he
5 hi --/M Hindi inc/hi
5 hr --/M Croatian zls/hr (hbs 5)
5 ht --/M Haitian_Creole roa/ht
5 hu --/M Hungarian urj/hu
5 hy --/M Armenian_(East_Armenia) ine/hy (hy-arevela 5)
5 hyw --/M Armenian_(West_Armenia) ine/hyw (hy-arevmda 5)(hy 8)
5 ia --/M Interlingua art/ia
5 id --/M Indonesian poz/id
5 io --/M Ido art/io
5 is --/M Icelandic gmq/is
5 it --/M Italian roa/it
5 ja --/M Japanese jpx/ja
5 jbo --/M Lojban art/jbo
5 ka --/M Georgian ccs/ka
5 kk --/M Kazakh trk/kk
5 kl --/M Greenlandic esx/kl
5 kn --/M Kannada dra/kn
5 ko --/M Korean ko
5 kok --/M Konkani inc/kok
5 ku --/M Kurdish ira/ku
5 ky --/M Kyrgyz trk/ky
5 la --/M Latin itc/la
5 lb --/M Luxembourgish gmw/lb
5 lfn --/M Lingua_Franca_Nova art/lfn
5 lt --/M Lithuanian bat/lt
5 ltg --/M Latgalian bat/ltg
5 lv --/M Latvian bat/lv
5 mi --/M Māori poz/mi
5 mk --/M Macedonian zls/mk
5 ml --/M Malayalam dra/ml
5 mr --/M Marathi inc/mr
5 ms --/M Malay poz/ms
5 mt --/M Maltese sem/mt
5 mto --/M Totontepec_Mixe miz/mto
5 my --/M Myanmar_(Burmese) sit/my
5 nb --/M Norwegian_Bokmål gmq/nb (no 5)
5 nci --/M Nahuatl_(Classical) azc/nci
5 ne --/M Nepali inc/ne
5 nl --/M Dutch gmw/nl
5 nog --/M Nogai trk/nog
5 om --/M Oromo cus/om
5 or --/M Oriya inc/or
5 pa --/M Punjabi inc/pa
5 pap --/M Papiamento roa/pap
5 piqd --/M Klingon art/piqd
5 pl --/M Polish zlw/pl
5 pt --/M Portuguese_(Portugal) roa/pt (pt-pt 5)
5 pt-br --/M Portuguese_(Brazil) roa/pt-BR (pt 6)
5 py --/M Pyash art/py
5 qdb --/M Lang_Belta art/qdb
5 qu --/M Quechua qu
5 quc --/M K'iche' myn/quc
5 qya --/M Quenya art/qya
5 ro --/M Romanian roa/ro
5 ru --/M Russian zle/ru
5 ru-cl --/M Russian_(Classic) zle/ru-cl
2 ru-lv --/M Russian_(Latvia) zle/ru-LV
5 sd --/M Sindhi inc/sd
5 shn --/M Shan_(Tai_Yai) tai/shn
5 si --/M Sinhala inc/si
5 sjn --/M Sindarin art/sjn
5 sk --/M Slovak zlw/sk
5 sl --/M Slovenian zls/sl
5 smj --/M Lule_Saami urj/smj
5 sq --/M Albanian ine/sq
5 sr --/M Serbian zls/sr
5 sv --/M Swedish gmq/sv
5 sw --/M Swahili bnt/sw
5 ta --/M Tamil dra/ta
5 te --/M Telugu dra/te
5 th --/M Thai tai/th
5 tk --/M Turkmen trk/tk
5 tn --/M Setswana bnt/tn
5 tr --/M Turkish trk/tr
5 tt --/M Tatar trk/tt
5 ug --/M Uyghur trk/ug
5 uk --/M Ukrainian zle/uk
5 ur --/M Urdu inc/ur
5 uz --/M Uzbek trk/uz
5 vi --/M Vietnamese_(Northern) aav/vi
5 vi-vn-x-central --/M Vietnamese_(Central) aav/vi-VN-x-central
5 vi-vn-x-south --/M Vietnamese_(Southern) aav/vi-VN-x-south
5 yue --/M Chinese_(Cantonese) sit/yue (zh-yue 5)(zh 8)
5 yue --/M Chinese_(Cantonese,_latin_as_Jyutping) sit/yue-Latn-jyutping (zh-yue 5)(zh 8)
"""

View File

@@ -1 +0,0 @@
../../../ljspeech/TTS/vits/flow.py

View File

@@ -1,39 +0,0 @@
#!/usr/bin/env python3
from pypinyin import phrases_dict, pinyin_dict
from tokenizer import Tokenizer
def main():
filename = "lexicon.txt"
tokens = "./data/tokens.txt"
tokenizer = Tokenizer(tokens)
word_dict = pinyin_dict.pinyin_dict
phrases = phrases_dict.phrases_dict
with open(filename, "w", encoding="utf-8") as f:
for key in word_dict:
if not (0x4E00 <= key <= 0x9FFF):
continue
w = chr(key)
        # [1:-1] removes the initial "sil" and the final "eos"
tokens = tokenizer.text_to_tokens(w)[1:-1]
tokens = " ".join(tokens)
f.write(f"{w} {tokens}\n")
for key in phrases:
        # [1:-1] removes the initial "sil" and the final "eos"
tokens = tokenizer.text_to_tokens(key)[1:-1]
tokens = " ".join(tokens)
f.write(f"{key} {tokens}\n")
if __name__ == "__main__":
main()
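
Each line of the resulting lexicon.txt maps a word or phrase to its phoneme sequence. Following text_to_tokens above, single characters and phrases come out like this; a sketch of the expected output format:

你 n i3 #0
你好 n i3 #0 h ao3 #0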

View File

@@ -1 +0,0 @@
../../../ljspeech/TTS/vits/generator.py

View File

@@ -1 +0,0 @@
../../../ljspeech/TTS/vits/hifigan.py

View File

@@ -1 +0,0 @@
../../../ljspeech/TTS/vits/loss.py

View File

@@ -1 +0,0 @@
../../../ljspeech/TTS/vits/monotonic_align

View File

@@ -1 +0,0 @@
../local/pinyin_dict.py

View File

@@ -1 +0,0 @@
../../../ljspeech/TTS/vits/posterior_encoder.py

View File

@@ -1 +0,0 @@
../local/pypinyin-local.dict

View File

@@ -1 +0,0 @@
../../../ljspeech/TTS/vits/residual_coupling.py

View File

@@ -1,142 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to test the exported onnx model by vits/export-onnx.py
Use the onnx model to generate a wav:
./vits/test_onnx.py \
--model-filename vits/exp/vits-epoch-1000.onnx \
--tokens data/tokens.txt
"""
import argparse
import logging
import onnxruntime as ort
import torch
import torchaudio
from tokenizer import Tokenizer
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model-filename",
type=str,
required=True,
help="Path to the onnx model.",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
parser.add_argument(
"--text",
type=str,
default="Ask not what your country can do for you; ask what you can do for your country.",
help="Text to generate speech for",
)
parser.add_argument(
"--output-filename",
type=str,
default="test_onnx.wav",
help="Filename to save the generated wave file.",
)
return parser
class OnnxModel:
def __init__(self, model_filename: str):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.model = ort.InferenceSession(
model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
metadata = self.model.get_modelmeta().custom_metadata_map
self.sample_rate = int(metadata["sample_rate"])
def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
"""
Args:
          tokens:
            A 2-D tensor of shape (1, T).
          tokens_lens:
            A 1-D tensor of shape (1,) containing the number of tokens.
Returns:
A tensor of shape (1, T')
"""
noise_scale = torch.tensor([0.667], dtype=torch.float32)
noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
alpha = torch.tensor([1.0], dtype=torch.float32)
out = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: tokens.numpy(),
self.model.get_inputs()[1].name: tokens_lens.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: alpha.numpy(),
self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
},
)[0]
return torch.from_numpy(out)
def main():
args = get_parser().parse_args()
logging.info(vars(args))
tokenizer = Tokenizer(args.tokens)
logging.info("About to create onnx model")
model = OnnxModel(args.model_filename)
text = args.text
tokens = tokenizer.texts_to_token_ids([text])
tokens = torch.tensor(tokens) # (1, T)
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1,)
audio = model(tokens, tokens_lens) # (1, T')
output_filename = args.output_filename
torchaudio.save(output_filename, audio, sample_rate=model.sample_rate)
logging.info(f"Saved to {output_filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -1 +0,0 @@
../../../ljspeech/TTS/vits/text_encoder.py

View File

@@ -1 +0,0 @@
../local/tokenizer.py

View File

@@ -1,927 +0,0 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
import k2
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import BakerZhSpeechTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, setup_logger, str2bool
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=1000,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
parser.add_argument(
"--lr", type=float, default=2.0e-4, help="The base learning rate."
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--inf-check",
type=str2bool,
default=False,
help="Add hooks to check for infinite module outputs and gradients.",
)
parser.add_argument(
"--save-every-n",
type=int,
default=20,
help="""Save checkpoint after processing this number of epochs"
periodically. We save checkpoint to exp-dir/ whenever
params.cur_epoch % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
Since it will take around 1000 epochs, we suggest using a large
save_every_n to save disk space.
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
parser.add_argument(
"--model-type",
type=str,
default="high",
choices=["low", "medium", "high"],
help="""If not empty, valid values are: low, medium, high.
It controls the model size. low -> runs faster.
""",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
        - batch_idx_train: Used to write statistics to tensorboard. It
contains number of batches trained so far across
epochs.
        - log_interval: Print training loss if batch_idx % log_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
"""
params = AttributeDict(
{
# training params
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": -1, # 0
"log_interval": 50,
"valid_interval": 200,
"env_info": get_env_info(),
"sampling_rate": 48000,
"frame_shift": 256,
"frame_length": 1024,
"feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length
"n_mels": 80,
"lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
"lambda_mel": 45.0, # loss scaling coefficient for Mel loss
"lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss
"lambda_dur": 1.0, # loss scaling coefficient for duration loss
"lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model` and `optimizer` it also updates
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
Returns:
Return a dict containing previously saved training info.
"""
if params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
assert filename.is_file(), f"{filename} does not exist!"
saved_params = load_checkpoint(filename, model=model)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def get_model(params: AttributeDict) -> nn.Module:
mel_loss_params = {
"n_mels": params.n_mels,
"frame_length": params.frame_length,
"frame_shift": params.frame_shift,
}
model = VITS(
vocab_size=params.vocab_size,
feature_dim=params.feature_dim,
sampling_rate=params.sampling_rate,
model_type=params.model_type,
mel_loss_params=mel_loss_params,
lambda_adv=params.lambda_adv,
lambda_mel=params.lambda_mel,
lambda_feat_match=params.lambda_feat_match,
lambda_dur=params.lambda_dur,
lambda_kl=params.lambda_kl,
)
return model
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
"""Parse batch data"""
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"]
tokens = tokenizer.tokens_to_token_ids(tokens)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
return audio, audio_lens, features, features_lens, tokens, tokens_lens
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
optimizer_g: Optimizer,
optimizer_d: Optimizer,
scheduler_g: LRSchedulerType,
scheduler_d: LRSchedulerType,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
tokenizer:
Used to convert text to phonemes.
optimizer_g:
The optimizer for generator.
optimizer_d:
The optimizer for discriminator.
scheduler_g:
The learning rate scheduler for generator, we call step() every epoch.
scheduler_d:
The learning rate scheduler for discriminator, we call step() every epoch.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
        The scaler used for mixed precision training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to track the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=0,
)
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["tokens"])
audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
batch, tokenizer, device
)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
try:
with autocast(enabled=params.use_fp16):
# forward discriminator
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
)
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# update discriminator
optimizer_d.zero_grad()
scaler.scale(loss_d).backward()
scaler.step(optimizer_d)
with autocast(enabled=params.use_fp16):
# forward generator
loss_g, stats_g = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
return_sample=params.batch_idx_train % params.log_interval == 0,
)
for k, v in stats_g.items():
if "returned_sample" not in k:
loss_info[k] = v * batch_size
# update generator
optimizer_g.zero_grad()
scaler.scale(loss_g).backward()
scaler.step(optimizer_g)
scaler.update()
# summary stats
tot_loss = tot_loss + loss_info
except: # noqa
save_bad_model()
raise
if params.print_diagnostics and batch_idx == 5:
return
if params.batch_idx_train % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (
cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if params.batch_idx_train % params.log_interval == 0:
cur_lr_g = max(scheduler_g.get_last_lr())
cur_lr_d = max(scheduler_d.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
f"loss[{loss_info}], tot_loss[{tot_loss}], "
f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate_g", cur_lr_g, params.batch_idx_train
)
tb_writer.add_scalar(
"train/learning_rate_d", cur_lr_d, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if "returned_sample" in stats_g:
speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
tb_writer.add_audio(
"train/speech_hat_",
speech_hat_,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"train/speech_",
speech_,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_image(
"train/mel_hat_",
plot_feature(mel_hat_),
params.batch_idx_train,
dataformats="HWC",
)
tb_writer.add_image(
"train/mel_",
plot_feature(mel_),
params.batch_idx_train,
dataformats="HWC",
)
if (
params.batch_idx_train % params.valid_interval == 0
and not params.print_diagnostics
):
logging.info("Computing validation loss")
valid_info, (speech_hat, speech) = compute_validation_loss(
params=params,
model=model,
tokenizer=tokenizer,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
tb_writer.add_audio(
"train/valdi_speech_hat",
speech_hat,
params.batch_idx_train,
params.sampling_rate,
)
tb_writer.add_audio(
"train/valdi_speech",
speech,
params.batch_idx_train,
params.sampling_rate,
)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss


def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
rank: int = 0,
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
"""Run the validation process."""
model.eval()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
# used to summary the stats over iterations
tot_loss = MetricsTracker()
returned_sample = None
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
batch_size = len(batch["tokens"])
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
) = prepare_input(batch, tokenizer, device)
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
# forward discriminator
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
)
assert loss_d.requires_grad is False
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# forward generator
loss_g, stats_g = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
)
assert loss_g.requires_grad is False
for k, v in stats_g.items():
loss_info[k] = v * batch_size
# summary stats
tot_loss = tot_loss + loss_info
# infer for first batch:
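            # inference() synthesizes one utterance and also returns the
            # predicted per-token durations; the total duration in frames
            # times the hop size in samples (params.frame_shift) must match
            # the length of the generated waveform, which the assert below
            # checks.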
if batch_idx == 0 and rank == 0:
inner_model = model.module if isinstance(model, DDP) else model
audio_pred, _, duration = inner_model.inference(
text=tokens[0, : tokens_lens[0].item()]
)
audio_pred = audio_pred.data.cpu().numpy()
audio_len_pred = (
(duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
)
assert audio_len_pred == len(audio_pred), (
audio_len_pred,
len(audio_pred),
)
audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
returned_sample = (audio_pred, audio_gt)
if world_size > 1:
tot_loss.reduce(device)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss, returned_sample


def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
tokenizer: Tokenizer,
optimizer_g: torch.optim.Optimizer,
optimizer_d: torch.optim.Optimizer,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
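    # find_pessimistic_batches() returns, per batch criterion (e.g. the
    # largest total duration the sampler can produce), the single worst-case
    # batch; running a full forward/backward pass on each of these surfaces
    # CUDA OOM before a long training run starts.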
logging.info(
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
batch, tokenizer, device
)
try:
# for discriminator
with autocast(enabled=params.use_fp16):
loss_d, stats_d = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=False,
)
optimizer_d.zero_grad()
loss_d.backward()
# for generator
with autocast(enabled=params.use_fp16):
loss_g, stats_g = model(
text=tokens,
text_lengths=tokens_lens,
feats=features,
feats_lengths=features_lens,
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
)
optimizer_g.zero_grad()
loss_g.backward()
except Exception as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)


def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
logging.info(params)
logging.info("About to create model")
model = get_model(params)
generator = model.generator
discriminator = model.discriminator
num_param_g = sum([p.numel() for p in generator.parameters()])
logging.info(f"Number of parameters in generator: {num_param_g}")
num_param_d = sum([p.numel() for p in discriminator.parameters()])
logging.info(f"Number of parameters in discriminator: {num_param_d}")
logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer_g = torch.optim.AdamW(
generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
)
optimizer_d = torch.optim.AdamW(
discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
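    # With gamma=0.999875 and one scheduler step per epoch (see the end of
    # the epoch loop below), the learning rate decays very slowly: to roughly
    # 0.9876x its initial value after 100 epochs.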
if checkpoints is not None:
# load state_dict for optimizers
if "optimizer_g" in checkpoints:
logging.info("Loading optimizer_g state dict")
optimizer_g.load_state_dict(checkpoints["optimizer_g"])
if "optimizer_d" in checkpoints:
logging.info("Loading optimizer_d state dict")
optimizer_d.load_state_dict(checkpoints["optimizer_d"])
# load state_dict for schedulers
if "scheduler_g" in checkpoints:
logging.info("Loading scheduler_g state dict")
scheduler_g.load_state_dict(checkpoints["scheduler_g"])
if "scheduler_d" in checkpoints:
logging.info("Loading scheduler_d state dict")
scheduler_d.load_state_dict(checkpoints["scheduler_d"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
if params.inf_check:
register_inf_check_hooks(model)
baker_zh = BakerZhSpeechTtsDataModule(args)
train_cuts = baker_zh.train_cuts()

    def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 20.0:
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = baker_zh.train_dataloaders(train_cuts)
valid_cuts = baker_zh.valid_cuts()
valid_dl = baker_zh.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
tokenizer=tokenizer,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
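    # Start from a loss scale of 1.0; train_one_epoch() also grows the scale
    # manually (see the grad-scale block there) in addition to GradScaler's
    # automatic growth.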
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
logging.info(f"Start epoch {epoch}")
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
params.cur_epoch = epoch
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
train_one_epoch(
params=params,
model=model,
tokenizer=tokenizer,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint(
filename=filename,
params=params,
model=model,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
if rank == 0:
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
# step per epoch
scheduler_g.step()
scheduler_d.step()
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()


def main():
parser = get_parser()
BakerZhSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)


torch.set_num_threads(1)
torch.set_num_interop_threads(1)


if __name__ == "__main__":
main()

View File

@ -1 +0,0 @@
../../../ljspeech/TTS/vits/transform.py

View File

@ -1,330 +0,0 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
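

# Give each dataloader worker a distinct but reproducible random seed, so
# that per-worker randomness differs across workers yet is repeatable across
# runs.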
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed

    def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)


class BakerZhSpeechTtsDataModule:
"""
    DataModule for TTS experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
    It contains all the common data pipeline modules used in TTS
    experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in TTS tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
self.sampling_rate = 48000

    @classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/spectrogram"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)

    def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
sampling_rate = self.sampling_rate
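            # At 48 kHz, a 1024-sample window is ~21.3 ms and a 256-sample
            # hop is ~5.3 ms.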
config = SpectrogramConfig(
sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in seconds)
                frame_shift=256 / sampling_rate,  # (in seconds)
use_fft_mag=True,
)
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = self.sampling_rate
config = SpectrogramConfig(
sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in seconds)
                frame_shift=256 / sampling_rate,  # (in seconds)
use_fft_mag=True,
)
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = self.sampling_rate
config = SpectrogramConfig(
sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in seconds)
                frame_shift=256 / sampling_rate,  # (in seconds)
use_fft_mag=True,
)
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
return_cuts=self.args.return_cuts,
)
else:
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl

    @lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz"
)

    @lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(
self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz"
)

    @lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz"
)

View File

@ -1 +0,0 @@
../../../ljspeech/TTS/vits/utils.py

View File

@ -1 +0,0 @@
../../../ljspeech/TTS/vits/vits.py

View File

@ -1 +0,0 @@
../../../ljspeech/TTS/vits/wavenet.py