mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00

remove baker-zh

This commit is contained in:
parent f9bd5ced9d
commit 35578f0593
@ -1,7 +0,0 @@
# Introduction

[./symbols.py](./symbols.py) is copied from
https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py

[./pypinyin-local.dict](./pypinyin-local.dict) is copied from
https://github.com/UEhQZXI/vits_chinese/blob/master/misc/pypinyin-local.dict
@ -1,106 +0,0 @@
#!/usr/bin/env python3
# Copyright    2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                       Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes the spectrogram features of the baker_zh dataset.
It looks for manifests in the directory data/manifests.

The generated spectrogram features are saved in data/spectrogram.
"""

import logging
import os
from pathlib import Path

import torch
from lhotse import (
    CutSet,
    LilcomChunkyWriter,
    Spectrogram,
    SpectrogramConfig,
    load_manifest,
)
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_spectrogram_baker_zh():
    src_dir = Path("data/manifests")
    output_dir = Path("data/spectrogram")
    num_jobs = min(4, os.cpu_count())

    sampling_rate = 48000
    frame_length = 1024 / sampling_rate  # (in seconds)
    frame_shift = 256 / sampling_rate  # (in seconds)
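    # At 48 kHz this gives a window of about 21.3 ms and a hop of about 5.3 ms.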
    use_fft_mag = True

    prefix = "baker_zh"
    suffix = "jsonl.gz"
    partition = "all"

    recordings = load_manifest(
        src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
    )
    supervisions = load_manifest(
        src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
    )

    config = SpectrogramConfig(
        sampling_rate=sampling_rate,
        frame_length=frame_length,
        frame_shift=frame_shift,
        use_fft_mag=use_fft_mag,
    )
    extractor = Spectrogram(config)

    with get_executor() as ex:  # Initialize the executor only once.
        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
        if (output_dir / cuts_filename).is_file():
            logging.info(f"{cuts_filename} already exists - skipping.")
            return
        logging.info(f"Processing {partition}")
        cut_set = CutSet.from_manifests(
            recordings=recordings, supervisions=supervisions
        )

        cut_set = cut_set.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{output_dir}/{prefix}_feats_{partition}",
            # when an executor is specified, make more partitions
            num_jobs=num_jobs if ex is None else 80,
            executor=ex,
            storage_type=LilcomChunkyWriter,
        )
        cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    compute_spectrogram_baker_zh()
@ -1,421 +0,0 @@
# This dict is copied from
# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py
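# Each value is an (initial, final) pair; "^" marks an empty initial,
# i.e., a syllable that begins with a vowel.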
pinyin_dict = {
    "a": ("^", "a"),
    "ai": ("^", "ai"),
    "an": ("^", "an"),
    "ang": ("^", "ang"),
    "ao": ("^", "ao"),
    "ba": ("b", "a"),
    "bai": ("b", "ai"),
    "ban": ("b", "an"),
    "bang": ("b", "ang"),
    "bao": ("b", "ao"),
    "be": ("b", "e"),
    "bei": ("b", "ei"),
    "ben": ("b", "en"),
    "beng": ("b", "eng"),
    "bi": ("b", "i"),
    "bian": ("b", "ian"),
    "biao": ("b", "iao"),
    "bie": ("b", "ie"),
    "bin": ("b", "in"),
    "bing": ("b", "ing"),
    "bo": ("b", "o"),
    "bu": ("b", "u"),
    "ca": ("c", "a"),
    "cai": ("c", "ai"),
    "can": ("c", "an"),
    "cang": ("c", "ang"),
    "cao": ("c", "ao"),
    "ce": ("c", "e"),
    "cen": ("c", "en"),
    "ceng": ("c", "eng"),
    "cha": ("ch", "a"),
    "chai": ("ch", "ai"),
    "chan": ("ch", "an"),
    "chang": ("ch", "ang"),
    "chao": ("ch", "ao"),
    "che": ("ch", "e"),
    "chen": ("ch", "en"),
    "cheng": ("ch", "eng"),
    "chi": ("ch", "iii"),
    "chong": ("ch", "ong"),
    "chou": ("ch", "ou"),
    "chu": ("ch", "u"),
    "chua": ("ch", "ua"),
    "chuai": ("ch", "uai"),
    "chuan": ("ch", "uan"),
    "chuang": ("ch", "uang"),
    "chui": ("ch", "uei"),
    "chun": ("ch", "uen"),
    "chuo": ("ch", "uo"),
    "ci": ("c", "ii"),
    "cong": ("c", "ong"),
    "cou": ("c", "ou"),
    "cu": ("c", "u"),
    "cuan": ("c", "uan"),
    "cui": ("c", "uei"),
    "cun": ("c", "uen"),
    "cuo": ("c", "uo"),
    "da": ("d", "a"),
    "dai": ("d", "ai"),
    "dan": ("d", "an"),
    "dang": ("d", "ang"),
    "dao": ("d", "ao"),
    "de": ("d", "e"),
    "dei": ("d", "ei"),
    "den": ("d", "en"),
    "deng": ("d", "eng"),
    "di": ("d", "i"),
    "dia": ("d", "ia"),
    "dian": ("d", "ian"),
    "diao": ("d", "iao"),
    "die": ("d", "ie"),
    "ding": ("d", "ing"),
    "diu": ("d", "iou"),
    "dong": ("d", "ong"),
    "dou": ("d", "ou"),
    "du": ("d", "u"),
    "duan": ("d", "uan"),
    "dui": ("d", "uei"),
    "dun": ("d", "uen"),
    "duo": ("d", "uo"),
    "e": ("^", "e"),
    "ei": ("^", "ei"),
    "en": ("^", "en"),
    "ng": ("^", "en"),
    "eng": ("^", "eng"),
    "er": ("^", "er"),
    "fa": ("f", "a"),
    "fan": ("f", "an"),
    "fang": ("f", "ang"),
    "fei": ("f", "ei"),
    "fen": ("f", "en"),
    "feng": ("f", "eng"),
    "fo": ("f", "o"),
    "fou": ("f", "ou"),
    "fu": ("f", "u"),
    "ga": ("g", "a"),
    "gai": ("g", "ai"),
    "gan": ("g", "an"),
    "gang": ("g", "ang"),
    "gao": ("g", "ao"),
    "ge": ("g", "e"),
    "gei": ("g", "ei"),
    "gen": ("g", "en"),
    "geng": ("g", "eng"),
    "gong": ("g", "ong"),
    "gou": ("g", "ou"),
    "gu": ("g", "u"),
    "gua": ("g", "ua"),
    "guai": ("g", "uai"),
    "guan": ("g", "uan"),
    "guang": ("g", "uang"),
    "gui": ("g", "uei"),
    "gun": ("g", "uen"),
    "guo": ("g", "uo"),
    "ha": ("h", "a"),
    "hai": ("h", "ai"),
    "han": ("h", "an"),
    "hang": ("h", "ang"),
    "hao": ("h", "ao"),
    "he": ("h", "e"),
    "hei": ("h", "ei"),
    "hen": ("h", "en"),
    "heng": ("h", "eng"),
    "hong": ("h", "ong"),
    "hou": ("h", "ou"),
    "hu": ("h", "u"),
    "hua": ("h", "ua"),
    "huai": ("h", "uai"),
    "huan": ("h", "uan"),
    "huang": ("h", "uang"),
    "hui": ("h", "uei"),
    "hun": ("h", "uen"),
    "huo": ("h", "uo"),
    "ji": ("j", "i"),
    "jia": ("j", "ia"),
    "jian": ("j", "ian"),
    "jiang": ("j", "iang"),
    "jiao": ("j", "iao"),
    "jie": ("j", "ie"),
    "jin": ("j", "in"),
    "jing": ("j", "ing"),
    "jiong": ("j", "iong"),
    "jiu": ("j", "iou"),
    "ju": ("j", "v"),
    "juan": ("j", "van"),
    "jue": ("j", "ve"),
    "jun": ("j", "vn"),
    "ka": ("k", "a"),
    "kai": ("k", "ai"),
    "kan": ("k", "an"),
    "kang": ("k", "ang"),
    "kao": ("k", "ao"),
    "ke": ("k", "e"),
    "kei": ("k", "ei"),
    "ken": ("k", "en"),
    "keng": ("k", "eng"),
    "kong": ("k", "ong"),
    "kou": ("k", "ou"),
    "ku": ("k", "u"),
    "kua": ("k", "ua"),
    "kuai": ("k", "uai"),
    "kuan": ("k", "uan"),
    "kuang": ("k", "uang"),
    "kui": ("k", "uei"),
    "kun": ("k", "uen"),
    "kuo": ("k", "uo"),
    "la": ("l", "a"),
    "lai": ("l", "ai"),
    "lan": ("l", "an"),
    "lang": ("l", "ang"),
    "lao": ("l", "ao"),
    "le": ("l", "e"),
    "lei": ("l", "ei"),
    "leng": ("l", "eng"),
    "li": ("l", "i"),
    "lia": ("l", "ia"),
    "lian": ("l", "ian"),
    "liang": ("l", "iang"),
    "liao": ("l", "iao"),
    "lie": ("l", "ie"),
    "lin": ("l", "in"),
    "ling": ("l", "ing"),
    "liu": ("l", "iou"),
    "lo": ("l", "o"),
    "long": ("l", "ong"),
    "lou": ("l", "ou"),
    "lu": ("l", "u"),
    "lv": ("l", "v"),
    "luan": ("l", "uan"),
    "lve": ("l", "ve"),
    "lue": ("l", "ve"),
    "lun": ("l", "uen"),
    "luo": ("l", "uo"),
    "ma": ("m", "a"),
    "mai": ("m", "ai"),
    "man": ("m", "an"),
    "mang": ("m", "ang"),
    "mao": ("m", "ao"),
    "me": ("m", "e"),
    "mei": ("m", "ei"),
    "men": ("m", "en"),
    "meng": ("m", "eng"),
    "mi": ("m", "i"),
    "mian": ("m", "ian"),
    "miao": ("m", "iao"),
    "mie": ("m", "ie"),
    "min": ("m", "in"),
    "ming": ("m", "ing"),
    "miu": ("m", "iou"),
    "mo": ("m", "o"),
    "mou": ("m", "ou"),
    "mu": ("m", "u"),
    "na": ("n", "a"),
    "nai": ("n", "ai"),
    "nan": ("n", "an"),
    "nang": ("n", "ang"),
    "nao": ("n", "ao"),
    "ne": ("n", "e"),
    "nei": ("n", "ei"),
    "nen": ("n", "en"),
    "neng": ("n", "eng"),
    "ni": ("n", "i"),
    "nia": ("n", "ia"),
    "nian": ("n", "ian"),
    "niang": ("n", "iang"),
    "niao": ("n", "iao"),
    "nie": ("n", "ie"),
    "nin": ("n", "in"),
    "ning": ("n", "ing"),
    "niu": ("n", "iou"),
    "nong": ("n", "ong"),
    "nou": ("n", "ou"),
    "nu": ("n", "u"),
    "nv": ("n", "v"),
    "nuan": ("n", "uan"),
    "nve": ("n", "ve"),
    "nue": ("n", "ve"),
    "nuo": ("n", "uo"),
    "o": ("^", "o"),
    "ou": ("^", "ou"),
    "pa": ("p", "a"),
    "pai": ("p", "ai"),
    "pan": ("p", "an"),
    "pang": ("p", "ang"),
    "pao": ("p", "ao"),
    "pe": ("p", "e"),
    "pei": ("p", "ei"),
    "pen": ("p", "en"),
    "peng": ("p", "eng"),
    "pi": ("p", "i"),
    "pian": ("p", "ian"),
    "piao": ("p", "iao"),
    "pie": ("p", "ie"),
    "pin": ("p", "in"),
    "ping": ("p", "ing"),
    "po": ("p", "o"),
    "pou": ("p", "ou"),
    "pu": ("p", "u"),
    "qi": ("q", "i"),
    "qia": ("q", "ia"),
    "qian": ("q", "ian"),
    "qiang": ("q", "iang"),
    "qiao": ("q", "iao"),
    "qie": ("q", "ie"),
    "qin": ("q", "in"),
    "qing": ("q", "ing"),
    "qiong": ("q", "iong"),
    "qiu": ("q", "iou"),
    "qu": ("q", "v"),
    "quan": ("q", "van"),
    "que": ("q", "ve"),
    "qun": ("q", "vn"),
    "ran": ("r", "an"),
    "rang": ("r", "ang"),
    "rao": ("r", "ao"),
    "re": ("r", "e"),
    "ren": ("r", "en"),
    "reng": ("r", "eng"),
    "ri": ("r", "iii"),
    "rong": ("r", "ong"),
    "rou": ("r", "ou"),
    "ru": ("r", "u"),
    "rua": ("r", "ua"),
    "ruan": ("r", "uan"),
    "rui": ("r", "uei"),
    "run": ("r", "uen"),
    "ruo": ("r", "uo"),
    "sa": ("s", "a"),
    "sai": ("s", "ai"),
    "san": ("s", "an"),
    "sang": ("s", "ang"),
    "sao": ("s", "ao"),
    "se": ("s", "e"),
    "sen": ("s", "en"),
    "seng": ("s", "eng"),
    "sha": ("sh", "a"),
    "shai": ("sh", "ai"),
    "shan": ("sh", "an"),
    "shang": ("sh", "ang"),
    "shao": ("sh", "ao"),
    "she": ("sh", "e"),
    "shei": ("sh", "ei"),
    "shen": ("sh", "en"),
    "sheng": ("sh", "eng"),
    "shi": ("sh", "iii"),
    "shou": ("sh", "ou"),
    "shu": ("sh", "u"),
    "shua": ("sh", "ua"),
    "shuai": ("sh", "uai"),
    "shuan": ("sh", "uan"),
    "shuang": ("sh", "uang"),
    "shui": ("sh", "uei"),
    "shun": ("sh", "uen"),
    "shuo": ("sh", "uo"),
    "si": ("s", "ii"),
    "song": ("s", "ong"),
    "sou": ("s", "ou"),
    "su": ("s", "u"),
    "suan": ("s", "uan"),
    "sui": ("s", "uei"),
    "sun": ("s", "uen"),
    "suo": ("s", "uo"),
    "ta": ("t", "a"),
    "tai": ("t", "ai"),
    "tan": ("t", "an"),
    "tang": ("t", "ang"),
    "tao": ("t", "ao"),
    "te": ("t", "e"),
    "tei": ("t", "ei"),
    "teng": ("t", "eng"),
    "ti": ("t", "i"),
    "tian": ("t", "ian"),
    "tiao": ("t", "iao"),
    "tie": ("t", "ie"),
    "ting": ("t", "ing"),
    "tong": ("t", "ong"),
    "tou": ("t", "ou"),
    "tu": ("t", "u"),
    "tuan": ("t", "uan"),
    "tui": ("t", "uei"),
    "tun": ("t", "uen"),
    "tuo": ("t", "uo"),
    "wa": ("^", "ua"),
    "wai": ("^", "uai"),
    "wan": ("^", "uan"),
    "wang": ("^", "uang"),
    "wei": ("^", "uei"),
    "wen": ("^", "uen"),
    "weng": ("^", "ueng"),
    "wo": ("^", "uo"),
    "wu": ("^", "u"),
    "xi": ("x", "i"),
    "xia": ("x", "ia"),
    "xian": ("x", "ian"),
    "xiang": ("x", "iang"),
    "xiao": ("x", "iao"),
    "xie": ("x", "ie"),
    "xin": ("x", "in"),
    "xing": ("x", "ing"),
    "xiong": ("x", "iong"),
    "xiu": ("x", "iou"),
    "xu": ("x", "v"),
    "xuan": ("x", "van"),
    "xue": ("x", "ve"),
    "xun": ("x", "vn"),
    "ya": ("^", "ia"),
    "yan": ("^", "ian"),
    "yang": ("^", "iang"),
    "yao": ("^", "iao"),
    "ye": ("^", "ie"),
    "yi": ("^", "i"),
    "yin": ("^", "in"),
    "ying": ("^", "ing"),
    "yo": ("^", "iou"),
    "yong": ("^", "iong"),
    "you": ("^", "iou"),
    "yu": ("^", "v"),
    "yuan": ("^", "van"),
    "yue": ("^", "ve"),
    "yun": ("^", "vn"),
    "za": ("z", "a"),
    "zai": ("z", "ai"),
    "zan": ("z", "an"),
    "zang": ("z", "ang"),
    "zao": ("z", "ao"),
    "ze": ("z", "e"),
    "zei": ("z", "ei"),
    "zen": ("z", "en"),
    "zeng": ("z", "eng"),
    "zha": ("zh", "a"),
    "zhai": ("zh", "ai"),
    "zhan": ("zh", "an"),
    "zhang": ("zh", "ang"),
    "zhao": ("zh", "ao"),
    "zhe": ("zh", "e"),
    "zhei": ("zh", "ei"),
    "zhen": ("zh", "en"),
    "zheng": ("zh", "eng"),
    "zhi": ("zh", "iii"),
    "zhong": ("zh", "ong"),
    "zhou": ("zh", "ou"),
    "zhu": ("zh", "u"),
    "zhua": ("zh", "ua"),
    "zhuai": ("zh", "uai"),
    "zhuan": ("zh", "uan"),
    "zhuang": ("zh", "uang"),
    "zhui": ("zh", "uei"),
    "zhun": ("zh", "uen"),
    "zhuo": ("zh", "uo"),
    "zi": ("z", "ii"),
    "zong": ("z", "ong"),
    "zou": ("z", "ou"),
    "zu": ("z", "u"),
    "zuan": ("z", "uan"),
    "zui": ("z", "uei"),
    "zun": ("z", "uen"),
    "zuo": ("z", "uo"),
}
@ -1,53 +0,0 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file generates the file that maps tokens to IDs.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict

from symbols import symbols


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--tokens",
        type=Path,
        default=Path("data/tokens.txt"),
        help="Path to the dict that maps the text tokens to IDs",
    )

    return parser.parse_args()


def main():
    args = get_args()
    tokens = Path(args.tokens)

    with open(tokens, "w", encoding="utf-8") as f:
        for token_id, token in enumerate(symbols):
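            # Each line has the form "<token> <id>", e.g., "sil 0".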
            f.write(f"{token} {token_id}\n")


if __name__ == "__main__":
    main()
@ -1,59 +0,0 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file reads the texts in the given manifest and saves the new cuts with tokens.
"""

import logging
from pathlib import Path

from lhotse import CutSet, load_manifest

from tokenizer import Tokenizer


def prepare_tokens_baker_zh():
    output_dir = Path("data/spectrogram")
    prefix = "baker_zh"
    suffix = "jsonl.gz"
    partition = "all"

    cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")

    tokenizer = Tokenizer()

    new_cuts = []
    for cut in cut_set:
        # Each cut only contains one supervision
        assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
        text = cut.supervisions[0].normalized_text
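        # Attach the phoneme sequence to the cut so that later stages can
        # read it directly from the manifest.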
        cut.tokens = tokenizer.text_to_tokens(text)

        new_cuts.append(cut)

    new_cut_set = CutSet.from_cuts(new_cuts)
    new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    prepare_tokens_baker_zh()
@ -1,328 +0,0 @@
姐姐 jie3 jie
宝宝 bao3 bao
哥哥 ge1 ge
妹妹 mei4 mei
弟弟 di4 di
妈妈 ma1 ma
开心哦 kai1 xin1 o
爸爸 ba4 ba
秘密哟 mi4 mi4 yo
哦 o
一年 yi4 nian2
一夜 yi2 ye4
一切 yi2 qie4
一座 yi2 zuo4
一下 yi2 xia4
上一山 shang4 yi2 shan1
下一山 xia4 yi2 shan1
休息 xiu1 xi2
东西 dong1 xi
上一届 shang4 yi2 jie4
便宜 pian2 yi4
加长 jia1 chang2
单田芳 shan4 tian2 fang1
帧 zhen1
长时间 chang2 shi2 jian1
长时 chang2 shi2
识别 shi2 bie2
生命中 sheng1 ming4 zhong1
踏实 ta1 shi
嗯 en4
溜达 liu1 da
少儿 shao4 er2
爷爷 ye2 ye
不是 bu2 shi4
一圈 yi1 quan1
厜读一声 zui1 du2 yi4 sheng1
一种 yi4 zhong3
一簇簇 yi2 cu4 cu4
一个 yi2 ge4
一样 yi2 yang4
一跩一跩 yi4 zhuai3 yi4 zhuai3
一会儿 yi2 hui4 er
一幢 yi2 zhuang4
挨了 ai2 le
熬菜 ao1 cai4
扒鸡 pa2 ji1
背枪 bei1 qiang1
绷瓷儿 beng4 ci2 er2
绷劲儿 beng3 jin4 er
绷着脸 beng3 zhe lian3
藏医 zang4 yi1
噌吰 cheng1 hong2
差点儿 cha4 dian3 er
差失 cha1 shi1
差误 cha1 wu4
孱头 can4 tou
乘间 cheng2 jian4
锄镰棘矜 chu2 lian2 ji2 qin2
川藏 chuan1 zang4
穿著 chuan1 zhuo2
答讪 da1 shan4
答言 da1 yan2
大伯子 da4 bai3 zi
大夫 dai4 fu
弹冠 tan2 guan1
当间 dang1 jian4
当然咯 dang1 ran2 lo
点种 dian3 zhong3
垛好 duo4 hao3
发疟子 fa1 yao4 zi
饭熟了 fan4 shou2 le
附著 fu4 zhuo2
复沓 fu4 ta4
供稿 gong1 gao3
供养 gong1 yang3
骨朵 gu1 duo
骨碌 gu1 lu
果脯 guo3 fu3
哈什玛 ha4 shi2 ma3
海蜇 hai3 zhe2
呵欠 he1 qian
河水汤汤 he2 shui3 shang1 shang1
鹄立 hu2 li4
鹄望 hu2 wang4
混人 hun2 ren2
混水 hun2 shui3
鸡血 ji1 xie3
缉鞋口 qi1 xie2 kou3
亟来闻讯 qi4 lai2 wen2 xun4
计量 ji4 liang2
济水 ji3 shui3
间杂 jian4 za2
脚跐两只船 jiao3 ci3 liang3 zhi1 chuan2
脚儿 jue2 er2
口角 kou3 jiao3
勒石 le4 shi2
累进 lei3 jin4
累累如丧家之犬 lei2 lei2 ru2 sang4 jia1 zhi1 quan3
累年 lei3 nian2
脸涨通红 lian3 zhang4 tong1 hong2
踉锵 liang4 qiang1
燎眉毛 liao3 mei2 mao2
燎头发 liao3 tou2 fa4
溜达 liu1 da
溜缝儿 liu4 feng4 er
馏口饭 liu4 kou3 fan4
遛马 liu4 ma3
遛鸟 liu4 niao3
遛弯儿 liu4 wan1 er
楼枪机 lou1 qiang1 ji1
搂钱 lou1 qian2
鹿脯 lu4 fu3
露头 lou4 tou2
落魄 luo4 po4
捋胡子 lv3 hu2 zi
绿地 lv4 di4
麦垛 mai4 duo4
没劲儿 mei2 jin4 er
闷棍 men4 gun4
闷葫芦 men4 hu2 lu
闷头干 men1 tou2 gan4
蒙古 meng3 gu3
靡日不思 mi3 ri4 bu4 si1
缪姓 miao4 xing4
抹墙 mo4 qiang2
抹下脸 ma1 xia4 lian3
泥子 ni4 zi
拗不过 niu4 bu guo4
排车 pai3 che1
盘诘 pan2 jie2
膀肿 pang1 zhong3
炮干 bao1 gan1
炮格 pao2 ge2
碰钉子 peng4 ding1 zi
缥色 piao3 se4
瀑河 bao4 he2
蹊径 xi1 jing4
前后相属 qian2 hou4 xiang1 zhu3
翘尾巴 qiao4 wei3 ba
趄坡儿 qie4 po1 er
秦桧 qin2 hui4
圈马 juan1 ma3
雀盲眼 qiao3 mang2 yan3
雀子 qiao1 zi
三年五载 san1 nian2 wu3 zai3
加载 jia1 zai3
山大王 shan1 dai4 wang
苫屋草 shan4 wu1 cao3
数数 shu3 shu4
说客 shui4 ke4
思量 si1 liang2
伺侯 ci4 hou
踏实 ta1 shi
提溜 di1 liu
调拨 diao4 bo1
帖子 tie3 zi
铜钿 tong2 tian2
头昏脑涨 tou2 hun1 nao3 zhang4
褪色 tui4 se4
褪着手 tun4 zhe shou3
圩子 wei2 zi
尾巴 wei3 ba
系好船只 xi4 hao3 chuan2 zhi1
系好马匹 xi4 hao3 ma3 pi3
杏脯 xing4 fu3
姓单 xing4 shan4
姓葛 xing4 ge3
姓哈 xing4 ha3
姓解 xing4 xie4
姓秘 xing4 bi4
姓宁 xing4 ning4
旋风 xuan4 feng1
旋根车轴 xuan4 gen1 che1 zhou2
荨麻 qian2 ma2
一幢楼房 yi1 zhuang4 lou2 fang2
遗之千金 wei4 zhi1 qian1 jin1
殷殷 yin3 yin3
应招 ying4 zhao1
用称约 yong4 cheng4 yao1
约斤肉 yao1 jin1 rou4
晕机 yun4 ji1
熨贴 yu4 tie1
咋办 za3 ban4
咋呼 zha1 hu
仔兽 zi3 shou4
扎彩 za1 cai3
扎实 zha1 shi
扎腰带 za1 yao1 dai4
轧朋友 ga2 peng2 you3
爪子 zhua3 zi
折腾 zhe1 teng
着实 zhuo2 shi2
着我旧时裳 zhuo2 wo3 jiu4 shi2 chang2
枝蔓 zhi1 man4
中鹄 zhong1 hu2
中选 zhong4 xuan3
猪圈 zhu1 juan4
拽住不放 zhuai4 zhu4 bu4 fang4
转悠 zhuan4 you
庄稼熟了 zhuang1 jia shou2 le
酌量 zhuo2 liang2
罪行累累 zui4 xing2 lei3 lei3
一手 yi4 shou3
一去不复返 yi2 qu4 bu2 fu4 fan3
一颗 yi4 ke1
一件 yi2 jian4
一斤 yi4 jin1
一点 yi4 dian3
一朵 yi4 duo3
一声 yi4 sheng1
一身 yi4 shen1
不要 bu2 yao4
一人 yi4 ren2
一个 yi2 ge4
一把 yi4 ba3
一门 yi4 men2
一門 yi4 men2
一艘 yi4 sou1
一片 yi2 pian4
一篇 yi2 pian1
一份 yi2 fen4
好嗲 hao3 dia3
随地 sui2 di4
扁担长 bian3 dan4 chang3
一堆 yi4 dui1
不义 bu2 yi4
放一放 fang4 yi2 fang4
一米 yi4 mi3
一顿 yi2 dun4
一层楼 yi4 ceng2 lou2
一条 yi4 tiao2
一件 yi2 jian4
一棵 yi4 ke1
一小股 yi4 xiao3 gu3
一拐一拐 yi4 guai3 yi4 guai3
一根 yi4 gen1
沆瀣一气 hang4 xie4 yi2 qi4
一丝 yi4 si1
一毫 yi4 hao2
一樣 yi2 yang4
处处 chu4 chu4
一餐 yi4 can
永不 yong3 bu2
一看 yi2 kan4
一架 yi2 jia4
送还 song4 huan2
一见 yi2 jian4
一座 yi2 zuo4
一块 yi2 kuai4
一天 yi4 tian1
一只 yi4 zhi1
一支 yi4 zhi1
一字 yi2 zi4
一句 yi2 ju4
一张 yi4 zhang1
一條 yi4 tiao2
一场 yi4 chang3
一粒 yi2 li4
小俩口 xiao3 liang3 kou3
一首 yi4 shou3
一对 yi2 dui4
一手 yi4 shou3
又一村 you4 yi4 cun1
一概而论 yi2 gai4 er2 lun4
一峰峰 yi4 feng1 feng1
不但 bu2 dan4
一笑 yi2 xiao4
挠痒痒 nao2 yang3 yang
不对 bu2 dui4
拧开 ning3 kai1
爱不释手 ai4 bu2 shi4 shou3
一念 yi2 nian4
夺得 duo2 de2
一袭 yi4 xi2
一定 yi2 ding4
不慎 bu2 shen4
剽窃 piao2 qie4
一时 yi4 shi2
撇开 pie3 kai1
一祭 yi2 ji4
发卡 fa4 qia3
少不了 shao3 bu4 liao3
千虑一失 qian1 lv4 yi4 shi1
呛得 qiang4 de2
切菜 qie1 cai4
茄盒 qie2 he2
不去 bu2 qu4
一大圈 yi2 da4 quan1
不再 bu2 zai4
一群 yi4 qun2
不必 bu2 bi4
一些 yi4 xie1
一路 yi2 lu4
一股 yi4 gu3
一到 yi2 dao4
一拨 yi4 bo1
一排 yi4 pai2
一空 yi4 kong1
吮吸着 shun3 xi1 zhe
不适合 bu2 shi4 he2
一串串 yi2 chuan4 chuan4
一提起 yi4 ti2 qi3
一尘不染 yi4 chen2 bu4 ran3
一生 yi4 sheng1
一派 yi2 pai4
不断 bu2 duan4
一次 yi2 ci4
不进步 bu2 jin4 bu4
娃娃 wa2 wa
万户侯 wan4 hu4 hou2
一方 yi4 fang1
一番话 yi4 fan1 hua4
一遍 yi2 bian4
不计较 bu2 ji4 jiao4
诇 xiong4
一边 yi4 bian1
一束 yi2 shu4
一听到 yi4 ting1 dao4
炸鸡 zha2 ji1
乍暧还寒 zha4 ai4 huan2 han2
我说诶 wo3 shuo1 ei1
棒诶 bang4 ei1
寒碜 han2 chen4
应采儿 ying4 cai3 er2
晕车 yun1 che1
必应 bi4 ying4
应援 ying4 yuan2
应力 ying4 li4
@ -1,73 +0,0 @@
# This file is copied from
# https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py
_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"]

_initials = [
    "^",
    "b",
    "c",
    "ch",
    "d",
    "f",
    "g",
    "h",
    "j",
    "k",
    "l",
    "m",
    "n",
    "p",
    "q",
    "r",
    "s",
    "sh",
    "t",
    "x",
    "z",
    "zh",
]

_tones = ["1", "2", "3", "4", "5"]

_finals = [
    "a",
    "ai",
    "an",
    "ang",
    "ao",
    "e",
    "ei",
    "en",
    "eng",
    "er",
    "i",
    "ia",
    "ian",
    "iang",
    "iao",
    "ie",
    "ii",
    "iii",
    "in",
    "ing",
    "iong",
    "iou",
    "o",
    "ong",
    "ou",
    "u",
    "ua",
    "uai",
    "uan",
    "uang",
    "uei",
    "uen",
    "ueng",
    "uo",
    "v",
    "van",
    "ve",
    "vn",
]
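
# The full inventory is the pauses, the initials, and every final combined
# with every tone, e.g., "a1" ... "a5", ..., "vn1" ... "vn5".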
symbols = _pause + _initials + [i + j for i in _finals for j in _tones]
@ -1,137 +0,0 @@
# This file is modified from
# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py

import logging
from pathlib import Path
from typing import Dict, List

# Note pinyin_dict is from ./pinyin_dict.py
from pinyin_dict import pinyin_dict
from pypinyin import Style
from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
from pypinyin.converter import DefaultConverter
from pypinyin.core import Pinyin, load_phrases_dict
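

# NeutralToneWith5Mixin makes pypinyin render the neutral tone as tone "5",
# matching the five-tone inventory in symbols.py.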
class _MyConverter(NeutralToneWith5Mixin, DefaultConverter):
    pass


class Tokenizer:
    def __init__(self, tokens: str = ""):
        self._load_pinyin_dict()
        self._pinyin_parser = Pinyin(_MyConverter())

        if tokens != "":
            self._load_tokens(tokens)

    def texts_to_token_ids(self, texts: List[str], **kwargs) -> List[List[int]]:
        """
        Args:
          texts:
            A list of sentences.
          kwargs:
            Not used. It is for compatibility with other TTS recipes in icefall.
        """
        tokens = []

        for text in texts:
            tokens.append(self.text_to_tokens(text))

        return self.tokens_to_token_ids(tokens)

    def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]:
        ans = []

        for token_list in tokens:
            token_ids = []
            for t in token_list:
                if t not in self.token2id:
                    logging.warning(f"Skip OOV {t}")
                    continue
                token_ids.append(self.token2id[t])
            ans.append(token_ids)

        return ans

    def text_to_tokens(self, text: str) -> List[str]:
        # Convert "," to ["sp", "sil"]
        # Convert "。" to ["sil"]
        # append ["eos"] at the end of a sentence
        phonemes = ["sil"]
        pinyins = self._pinyin_parser.pinyin(
            text,
            style=Style.TONE3,
            errors=lambda x: [[w] for w in x],
        )

        new_pinyin = []
        for p in pinyins:
            p = p[0]
            if p == ",":
                new_pinyin.extend(["sp", "sil"])
            elif p == "。":
                new_pinyin.append("sil")
            else:
                new_pinyin.append(p)
        sub_phonemes = self._get_phoneme4pinyin(new_pinyin)
        sub_phonemes.append("eos")
        phonemes.extend(sub_phonemes)
        return phonemes

    def _get_phoneme4pinyin(self, pinyins):
        result = []
        for pinyin in pinyins:
            if pinyin in ("sil", "sp"):
                result.append(pinyin)
            elif pinyin[:-1] in pinyin_dict:
                tone = pinyin[-1]
                a = pinyin[:-1]
                a1, a2 = pinyin_dict[a]
                # every word is appended with a #0
                result += [a1, a2 + tone, "#0"]

        return result

    def _load_pinyin_dict(self):
        this_dir = Path(__file__).parent.resolve()
        my_dict = {}
        with open(f"{this_dir}/pypinyin-local.dict", "r", encoding="utf-8") as f:
            content = f.readlines()
            for line in content:
                cuts = line.strip().split()
                hanzi = cuts[0]
                pinyin = cuts[1:]
                my_dict[hanzi] = [[p] for p in pinyin]

        load_phrases_dict(my_dict)

    def _load_tokens(self, filename):
        token2id: Dict[str, int] = {}

        with open(filename, "r", encoding="utf-8") as f:
            for line in f.readlines():
                info = line.rstrip().split()
                if len(info) == 1:
                    # case of space
                    token = " "
                    idx = int(info[0])
                else:
                    token, idx = info[0], int(info[1])

                assert token not in token2id, token

                token2id[token] = idx

        self.token2id = token2id
        self.vocab_size = len(self.token2id)
        self.pad_id = self.token2id["#0"]


def main():
    tokenizer = Tokenizer()
    tokenizer.text_to_tokens("你好,好的。")


if __name__ == "__main__":
    main()
@ -1 +0,0 @@
../../../ljspeech/TTS/local/validate_manifest.py
@ -1,124 +0,0 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

stage=-1
stop_stage=100

dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: build monotonic_align lib"
  if [ ! -d vits/monotonic_align/build ]; then
    cd vits/monotonic_align
    python3 setup.py build_ext --inplace
    cd ../../
  else
    log "monotonic_align lib already built"
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Download data"

  # The directory $dl_dir/BZNSYP will contain 3 sub directories:
  #  - PhoneLabeling
  #  - ProsodyLabeling
  #  - Wave

  # If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink
  #
  #   ln -sfv /path/to/BZNSYP $dl_dir/
  #   touch $dl_dir/BZNSYP/.completed
  #
  if [ ! -d $dl_dir/BZNSYP ]; then
    lhotse download baker-zh $dl_dir
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Prepare baker-zh manifest"
  # We assume that you have downloaded the baker corpus
  # to $dl_dir/BZNSYP
  mkdir -p data/manifests
  if [ ! -e data/manifests/.baker.done ]; then
    lhotse prepare baker-zh $dl_dir/BZNSYP data/manifests
    touch data/manifests/.baker.done
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Compute spectrogram for baker (may take 3 minutes)"
  mkdir -p data/spectrogram
  if [ ! -e data/spectrogram/.baker.done ]; then
    ./local/compute_spectrogram_baker.py
    touch data/spectrogram/.baker.done
  fi

  if [ ! -e data/spectrogram/.baker-validated.done ]; then
    log "Validating data/spectrogram for baker"
    python3 ./local/validate_manifest.py \
      data/spectrogram/baker_zh_cuts_all.jsonl.gz
    touch data/spectrogram/.baker-validated.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Prepare tokens for baker-zh (may take 20 seconds)"
  if [ ! -e data/spectrogram/.baker_zh_with_token.done ]; then

    ./local/prepare_tokens_baker_zh.py

    mv -v data/spectrogram/baker_zh_cuts_with_tokens_all.jsonl.gz \
      data/spectrogram/baker_zh_cuts_all.jsonl.gz

    touch data/spectrogram/.baker_zh_with_token.done
  fi
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Split the baker-zh cuts into train, valid and test sets (may take 25 seconds)"
  if [ ! -e data/spectrogram/.baker_zh_split.done ]; then
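    # Hold out the last 600 cuts: the first 100 of them become the valid
    # set and the remaining 500 the test set; everything else is train.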
    lhotse subset --last 600 \
      data/spectrogram/baker_zh_cuts_all.jsonl.gz \
      data/spectrogram/baker_zh_cuts_validtest.jsonl.gz
    lhotse subset --first 100 \
      data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \
      data/spectrogram/baker_zh_cuts_valid.jsonl.gz
    lhotse subset --last 500 \
      data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \
      data/spectrogram/baker_zh_cuts_test.jsonl.gz

    rm data/spectrogram/baker_zh_cuts_validtest.jsonl.gz

    n=$(( $(gunzip -c data/spectrogram/baker_zh_cuts_all.jsonl.gz | wc -l) - 600 ))
    lhotse subset --first $n \
      data/spectrogram/baker_zh_cuts_all.jsonl.gz \
      data/spectrogram/baker_zh_cuts_train.jsonl.gz
    touch data/spectrogram/.baker_zh_split.done
  fi
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Generate token file"
  if [ ! -e data/tokens.txt ]; then
    ./local/prepare_token_file.py --tokens data/tokens.txt
  fi
fi
@ -1 +0,0 @@
../../../icefall/shared
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/duration_predictor.py
@ -1,414 +0,0 @@
#!/usr/bin/env python3
#
# Copyright      2023 Xiaomi Corporation     (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script exports a VITS model from PyTorch to ONNX.

Export the model to ONNX:
./vits/export-onnx.py \
  --epoch 1000 \
  --exp-dir vits/exp \
  --tokens data/tokens.txt

It will generate one file inside vits/exp:
  - vits-epoch-1000.onnx

See ./test_onnx.py for how to use the exported ONNX models.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple

import onnx
import torch
import torch.nn as nn
from tokenizer import Tokenizer
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=1000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="vits/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    parser.add_argument(
        "--model-type",
        type=str,
        default="high",
        choices=["low", "medium", "high"],
        help="""If not empty, valid values are: low, medium, high.
        It controls the model size. low -> runs faster.
        """,
    )

    return parser


def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


class OnnxModel(nn.Module):
    """A wrapper for VITS generator."""

    def __init__(self, model: nn.Module):
        """
        Args:
          model:
            A VITS generator.
        """
        super().__init__()
        self.model = model

    def forward(
        self,
        tokens: torch.Tensor,
        tokens_lens: torch.Tensor,
        noise_scale: float = 0.667,
        alpha: float = 1.0,
        noise_scale_dur: float = 0.8,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Please see the help information of VITS.inference_batch

        Args:
          tokens:
            Input text token indexes (1, T_text)
          tokens_lens:
            Number of tokens of shape (1,)
          noise_scale (float):
            Noise scale parameter for flow.
          noise_scale_dur (float):
            Noise scale parameter for duration predictor.
          alpha (float):
            Alpha parameter to control the speed of generated speech.

        Returns:
          Return a tuple containing:
            - audio, generated waveform tensor, (B, T_wav)
        """
        audio, _, _ = self.model.generator.inference(
            text=tokens,
            text_lengths=tokens_lens,
            noise_scale=noise_scale,
            noise_scale_dur=noise_scale_dur,
            alpha=alpha,
        )
        return audio


def export_model_onnx(
    model: nn.Module,
    model_filename: str,
    vocab_size: int,
    opset_version: int = 11,
) -> None:
    """Export the given generator model to ONNX format.
    The exported model takes the following inputs:

        - tokens, a tensor of shape (1, T_text); dtype is torch.int64
        - tokens_lens, a tensor of shape (1,); dtype is torch.int64
        - noise_scale, alpha and noise_scale_dur, scalar float tensors

    and it has one output:

        - audio, a tensor of shape (1, T'); dtype is torch.float32

    Args:
      model:
        The VITS generator.
      model_filename:
        The filename to save the exported ONNX model.
      vocab_size:
        Number of tokens used in training.
      opset_version:
        The opset version to use.
    """
    tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
    noise_scale = torch.tensor([1], dtype=torch.float32)
    noise_scale_dur = torch.tensor([1], dtype=torch.float32)
    alpha = torch.tensor([1], dtype=torch.float32)

    torch.onnx.export(
        model,
        (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur),
        model_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "tokens",
            "tokens_lens",
            "noise_scale",
            "alpha",
            "noise_scale_dur",
        ],
        output_names=["audio"],
        dynamic_axes={
            "tokens": {0: "N", 1: "T"},
            "tokens_lens": {0: "N"},
            "audio": {0: "N", 1: "T"},
        },
    )

    if model.model.spks is None:
        num_speakers = 1
    else:
        num_speakers = model.model.spks

    meta_data = {
        "model_type": "vits",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "icefall",  # must be icefall for models from icefall
        "language": "Chinese",
        "n_speakers": num_speakers,
        "sample_rate": model.model.sampling_rate,  # Must match the real sample rate
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=model_filename, meta_data=meta_data)


@torch.no_grad()
def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

    model.to("cpu")
    model.eval()

    model = OnnxModel(model=model)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")

    suffix = f"epoch-{params.epoch}"

    opset_version = 13

    logging.info("Exporting generator")
    model_filename = params.exp_dir / f"vits-{suffix}.onnx"
    export_model_onnx(
        model,
        model_filename,
        params.vocab_size,
        opset_version=opset_version,
    )
    logging.info(f"Exported generator to {model_filename}")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()

"""
Supported languages.

LJSpeech is using "en-us" from the second column.

Pty Language Age/Gender VoiceName File Other Languages
5 af --/M Afrikaans gmw/af
5 am --/M Amharic sem/am
5 an --/M Aragonese roa/an
5 ar --/M Arabic sem/ar
5 as --/M Assamese inc/as
5 az --/M Azerbaijani trk/az
5 ba --/M Bashkir trk/ba
5 be --/M Belarusian zle/be
5 bg --/M Bulgarian zls/bg
5 bn --/M Bengali inc/bn
5 bpy --/M Bishnupriya_Manipuri inc/bpy
5 bs --/M Bosnian zls/bs
5 ca --/M Catalan roa/ca
5 chr-US-Qaaa-x-west --/M Cherokee_ iro/chr
5 cmn --/M Chinese_(Mandarin,_latin_as_English) sit/cmn (zh-cmn 5)(zh 5)
5 cmn-latn-pinyin --/M Chinese_(Mandarin,_latin_as_Pinyin) sit/cmn-Latn-pinyin (zh-cmn 5)(zh 5)
5 cs --/M Czech zlw/cs
5 cv --/M Chuvash trk/cv
5 cy --/M Welsh cel/cy
5 da --/M Danish gmq/da
5 de --/M German gmw/de
5 el --/M Greek grk/el
5 en-029 --/M English_(Caribbean) gmw/en-029 (en 10)
2 en-gb --/M English_(Great_Britain) gmw/en (en 2)
5 en-gb-scotland --/M English_(Scotland) gmw/en-GB-scotland (en 4)
5 en-gb-x-gbclan --/M English_(Lancaster) gmw/en-GB-x-gbclan (en-gb 3)(en 5)
5 en-gb-x-gbcwmd --/M English_(West_Midlands) gmw/en-GB-x-gbcwmd (en-gb 9)(en 9)
5 en-gb-x-rp --/M English_(Received_Pronunciation) gmw/en-GB-x-rp (en-gb 4)(en 5)
2 en-us --/M English_(America) gmw/en-US (en 3)
5 en-us-nyc --/M English_(America,_New_York_City) gmw/en-US-nyc
5 eo --/M Esperanto art/eo
5 es --/M Spanish_(Spain) roa/es
5 es-419 --/M Spanish_(Latin_America) roa/es-419 (es-mx 6)
5 et --/M Estonian urj/et
5 eu --/M Basque eu
5 fa --/M Persian ira/fa
5 fa-latn --/M Persian_(Pinglish) ira/fa-Latn
5 fi --/M Finnish urj/fi
5 fr-be --/M French_(Belgium) roa/fr-BE (fr 8)
5 fr-ch --/M French_(Switzerland) roa/fr-CH (fr 8)
5 fr-fr --/M French_(France) roa/fr (fr 5)
5 ga --/M Gaelic_(Irish) cel/ga
5 gd --/M Gaelic_(Scottish) cel/gd
5 gn --/M Guarani sai/gn
5 grc --/M Greek_(Ancient) grk/grc
5 gu --/M Gujarati inc/gu
5 hak --/M Hakka_Chinese sit/hak
5 haw --/M Hawaiian map/haw
5 he --/M Hebrew sem/he
5 hi --/M Hindi inc/hi
5 hr --/M Croatian zls/hr (hbs 5)
5 ht --/M Haitian_Creole roa/ht
5 hu --/M Hungarian urj/hu
5 hy --/M Armenian_(East_Armenia) ine/hy (hy-arevela 5)
5 hyw --/M Armenian_(West_Armenia) ine/hyw (hy-arevmda 5)(hy 8)
5 ia --/M Interlingua art/ia
5 id --/M Indonesian poz/id
5 io --/M Ido art/io
5 is --/M Icelandic gmq/is
5 it --/M Italian roa/it
5 ja --/M Japanese jpx/ja
5 jbo --/M Lojban art/jbo
5 ka --/M Georgian ccs/ka
5 kk --/M Kazakh trk/kk
5 kl --/M Greenlandic esx/kl
5 kn --/M Kannada dra/kn
5 ko --/M Korean ko
5 kok --/M Konkani inc/kok
5 ku --/M Kurdish ira/ku
5 ky --/M Kyrgyz trk/ky
5 la --/M Latin itc/la
5 lb --/M Luxembourgish gmw/lb
5 lfn --/M Lingua_Franca_Nova art/lfn
5 lt --/M Lithuanian bat/lt
5 ltg --/M Latgalian bat/ltg
5 lv --/M Latvian bat/lv
5 mi --/M Māori poz/mi
5 mk --/M Macedonian zls/mk
5 ml --/M Malayalam dra/ml
5 mr --/M Marathi inc/mr
5 ms --/M Malay poz/ms
5 mt --/M Maltese sem/mt
5 mto --/M Totontepec_Mixe miz/mto
5 my --/M Myanmar_(Burmese) sit/my
5 nb --/M Norwegian_Bokmål gmq/nb (no 5)
5 nci --/M Nahuatl_(Classical) azc/nci
5 ne --/M Nepali inc/ne
5 nl --/M Dutch gmw/nl
5 nog --/M Nogai trk/nog
5 om --/M Oromo cus/om
5 or --/M Oriya inc/or
5 pa --/M Punjabi inc/pa
5 pap --/M Papiamento roa/pap
5 piqd --/M Klingon art/piqd
5 pl --/M Polish zlw/pl
5 pt --/M Portuguese_(Portugal) roa/pt (pt-pt 5)
5 pt-br --/M Portuguese_(Brazil) roa/pt-BR (pt 6)
5 py --/M Pyash art/py
5 qdb --/M Lang_Belta art/qdb
5 qu --/M Quechua qu
5 quc --/M K'iche' myn/quc
5 qya --/M Quenya art/qya
5 ro --/M Romanian roa/ro
5 ru --/M Russian zle/ru
5 ru-cl --/M Russian_(Classic) zle/ru-cl
2 ru-lv --/M Russian_(Latvia) zle/ru-LV
5 sd --/M Sindhi inc/sd
5 shn --/M Shan_(Tai_Yai) tai/shn
5 si --/M Sinhala inc/si
5 sjn --/M Sindarin art/sjn
5 sk --/M Slovak zlw/sk
5 sl --/M Slovenian zls/sl
5 smj --/M Lule_Saami urj/smj
5 sq --/M Albanian ine/sq
5 sr --/M Serbian zls/sr
5 sv --/M Swedish gmq/sv
5 sw --/M Swahili bnt/sw
5 ta --/M Tamil dra/ta
5 te --/M Telugu dra/te
5 th --/M Thai tai/th
5 tk --/M Turkmen trk/tk
5 tn --/M Setswana bnt/tn
5 tr --/M Turkish trk/tr
5 tt --/M Tatar trk/tt
5 ug --/M Uyghur trk/ug
5 uk --/M Ukrainian zle/uk
5 ur --/M Urdu inc/ur
5 uz --/M Uzbek trk/uz
5 vi --/M Vietnamese_(Northern) aav/vi
5 vi-vn-x-central --/M Vietnamese_(Central) aav/vi-VN-x-central
5 vi-vn-x-south --/M Vietnamese_(Southern) aav/vi-VN-x-south
5 yue --/M Chinese_(Cantonese) sit/yue (zh-yue 5)(zh 8)
5 yue --/M Chinese_(Cantonese,_latin_as_Jyutping) sit/yue-Latn-jyutping (zh-yue 5)(zh 8)
"""
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/flow.py
@ -1,39 +0,0 @@
#!/usr/bin/env python3

from pypinyin import phrases_dict, pinyin_dict
from tokenizer import Tokenizer


def main():
    filename = "lexicon.txt"
    tokens = "./data/tokens.txt"
    tokenizer = Tokenizer(tokens)

    word_dict = pinyin_dict.pinyin_dict
    phrases = phrases_dict.phrases_dict

    with open(filename, "w", encoding="utf-8") as f:
        for key in word_dict:
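            # Keep only code points in the CJK Unified Ideographs block.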
            if not (0x4E00 <= key <= 0x9FFF):
                continue

            w = chr(key)

            # 1 to remove the initial sil
            # :-1 to remove the final eos
            tokens = tokenizer.text_to_tokens(w)[1:-1]

            tokens = " ".join(tokens)
            f.write(f"{w} {tokens}\n")

        for key in phrases:
            # 1 to remove the initial sil
            # :-1 to remove the final eos
            tokens = tokenizer.text_to_tokens(key)[1:-1]
            tokens = " ".join(tokens)
            f.write(f"{key} {tokens}\n")


if __name__ == "__main__":
    main()
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/generator.py
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/hifigan.py
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/loss.py
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/monotonic_align
@ -1 +0,0 @@
../local/pinyin_dict.py
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/posterior_encoder.py
@ -1 +0,0 @@
../local/pypinyin-local.dict
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/residual_coupling.py
@ -1,142 +0,0 @@
#!/usr/bin/env python3
#
# Copyright      2023 Xiaomi Corporation     (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is used to test the exported onnx model by vits/export-onnx.py

Use the onnx model to generate a wav:
./vits/test_onnx.py \
  --model-filename vits/exp/vits-epoch-1000.onnx \
  --tokens data/tokens.txt
"""


import argparse
import logging

import onnxruntime as ort
import torch
import torchaudio
from tokenizer import Tokenizer


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-filename",
        type=str,
        required=True,
        help="Path to the onnx model.",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    parser.add_argument(
        "--text",
        type=str,
        default="Ask not what your country can do for you; ask what you can do for your country.",
        help="Text to generate speech for",
    )

    parser.add_argument(
        "--output-filename",
        type=str,
        default="test_onnx.wav",
        help="Filename to save the generated wave file.",
    )

    return parser


class OnnxModel:
    def __init__(self, model_filename: str):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")

        metadata = self.model.get_modelmeta().custom_metadata_map
        self.sample_rate = int(metadata["sample_rate"])

    def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
          tokens:
            A 2-D int64 tensor of shape (1, T)
        Returns:
          A tensor of shape (1, T')
        """
        noise_scale = torch.tensor([0.667], dtype=torch.float32)
        noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
        alpha = torch.tensor([1.0], dtype=torch.float32)

        out = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: tokens.numpy(),
                self.model.get_inputs()[1].name: tokens_lens.numpy(),
                self.model.get_inputs()[2].name: noise_scale.numpy(),
                self.model.get_inputs()[3].name: alpha.numpy(),
                self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
            },
        )[0]
        return torch.from_numpy(out)


def main():
    args = get_parser().parse_args()
    logging.info(vars(args))

    tokenizer = Tokenizer(args.tokens)

    logging.info("About to create onnx model")
    model = OnnxModel(args.model_filename)

    text = args.text
    tokens = tokenizer.texts_to_token_ids([text])
    tokens = torch.tensor(tokens)  # (1, T)
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1,)
    audio = model(tokens, tokens_lens)  # (1, T')

    output_filename = args.output_filename
    torchaudio.save(output_filename, audio, sample_rate=model.sample_rate)
    logging.info(f"Saved to {output_filename}")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/text_encoder.py
@ -1 +0,0 @@
../local/tokenizer.py
@ -1,927 +0,0 @@
|
||||
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union

import k2
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import BakerZhSpeechTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, setup_logger, str2bool

LRSchedulerType = torch.optim.lr_scheduler._LRScheduler

def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs for DDP training.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12354,
        help="Master port to use for DDP training.",
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Should various information be logged in tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=1000,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=1,
        help="""Resume training from this epoch. It should be positive.
        If larger than 1, it will load the checkpoint from
        exp-dir/epoch-{start_epoch-1}.pt
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="vits/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, logs, etc, are saved
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    parser.add_argument(
        "--lr", type=float, default=2.0e-4, help="The base learning rate."
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators; it is intended for reproducibility.",
    )

    parser.add_argument(
        "--print-diagnostics",
        type=str2bool,
        default=False,
        help="Accumulate stats on activations, print them, and exit.",
    )

    parser.add_argument(
        "--inf-check",
        type=str2bool,
        default=False,
        help="Add hooks to check for infinite module outputs and gradients.",
    )

    parser.add_argument(
        "--save-every-n",
        type=int,
        default=20,
        help="""Save a checkpoint periodically after processing this number of
        epochs. We save a checkpoint to exp-dir/ whenever
        params.cur_epoch % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
        Since training takes around 1000 epochs, we suggest using a large
        save_every_n to save disk space.
        """,
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=False,
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--model-type",
        type=str,
        default="high",
        choices=["low", "medium", "high"],
        help="""Valid values are: low, medium, high.
        It controls the model size. low -> runs faster.
        """,
    )

    return parser

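# Example invocation (paths are illustrative; --max-duration is added by
# BakerZhSpeechTtsDataModule.add_arguments in main()):
#
#   python3 ./vits/train.py \
#     --world-size 4 \
#     --num-epochs 1000 \
#     --exp-dir vits/exp \
#     --tokens data/tokens.txt \
#     --max-duration 350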
def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed from the commandline
    are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

        - best_train_loss: Best training loss so far. It is used to select
                           the model that has the lowest training loss. It is
                           updated during the training.

        - best_valid_loss: Best validation loss so far. It is used to select
                           the model that has the lowest validation loss. It is
                           updated during the training.

        - best_train_epoch: It is the epoch that has the best training loss.

        - best_valid_epoch: It is the epoch that has the best validation loss.

        - batch_idx_train: Used to write statistics to tensorboard. It
                           contains the number of batches trained so far across
                           epochs.

        - log_interval: Print training loss if batch_idx % log_interval is 0.

        - valid_interval: Run validation if batch_idx % valid_interval is 0.

        - feature_dim: The model input dim. It has to match the one used
                       in computing features.
    """
    params = AttributeDict(
        {
            # training params
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": -1,  # incremented to 0 on the first batch
            "log_interval": 50,
            "valid_interval": 200,
            "env_info": get_env_info(),
            "sampling_rate": 48000,
            "frame_shift": 256,
            "frame_length": 1024,
            "feature_dim": 513,  # 1024 // 2 + 1, where 1024 is fft_length
            "n_mels": 80,
            "lambda_adv": 1.0,  # loss scaling coefficient for adversarial loss
            "lambda_mel": 45.0,  # loss scaling coefficient for Mel loss
            "lambda_feat_match": 2.0,  # loss scaling coefficient for feat match loss
            "lambda_dur": 1.0,  # loss scaling coefficient for duration loss
            "lambda_kl": 1.0,  # loss scaling coefficient for KL divergence loss
        }
    )

    return params

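# For reference: with sampling_rate = 48000, frame_shift = 256 samples is
# 256 / 48000 ≈ 5.3 ms and frame_length = 1024 samples is 1024 / 48000 ≈ 21.3 ms;
# feature_dim = 1024 // 2 + 1 = 513 is the number of spectrogram bins.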
def load_checkpoint_if_available(
    params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.

    If params.start_epoch is larger than 1, it will load the checkpoint from
    `params.start_epoch - 1`.

    Apart from loading the state dict for `model`, it also updates
    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
    and `best_valid_loss` in `params`.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
    Returns:
      Return a dict containing previously saved training info.
    """
    if params.start_epoch > 1:
        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    else:
        return None

    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(filename, model=model)

    keys = [
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    ]
    for k in keys:
        params[k] = saved_params[k]

    return saved_params

def get_model(params: AttributeDict) -> nn.Module:
    mel_loss_params = {
        "n_mels": params.n_mels,
        "frame_length": params.frame_length,
        "frame_shift": params.frame_shift,
    }
    model = VITS(
        vocab_size=params.vocab_size,
        feature_dim=params.feature_dim,
        sampling_rate=params.sampling_rate,
        model_type=params.model_type,
        mel_loss_params=mel_loss_params,
        lambda_adv=params.lambda_adv,
        lambda_mel=params.lambda_mel,
        lambda_feat_match=params.lambda_feat_match,
        lambda_dur=params.lambda_dur,
        lambda_kl=params.lambda_kl,
    )
    return model

def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
    """Parse batch data."""
    audio = batch["audio"].to(device)
    features = batch["features"].to(device)
    audio_lens = batch["audio_lens"].to(device)
    features_lens = batch["features_lens"].to(device)
    tokens = batch["tokens"]

    tokens = tokenizer.tokens_to_token_ids(tokens)
    tokens = k2.RaggedTensor(tokens)
    row_splits = tokens.shape.row_splits(1)
    tokens_lens = row_splits[1:] - row_splits[:-1]
    tokens = tokens.to(device)
    tokens_lens = tokens_lens.to(device)
    # a tensor of shape (B, T)
    tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)

    return audio, audio_lens, features, features_lens, tokens, tokens_lens

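# A small worked example of the ragged-tensor bookkeeping above (assuming k2 is
# available): for tokens = [[1, 2, 3], [4, 5]], row_splits(1) is [0, 3, 5], so
# tokens_lens = [3, 2], and padding yields [[1, 2, 3], [4, 5, pad_id]].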
def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: Tokenizer,
    optimizer_g: Optimizer,
    optimizer_d: Optimizer,
    scheduler_g: LRSchedulerType,
    scheduler_d: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The training loss, averaged over all frames, is saved in
    `params.train_loss`. It runs the validation process every
    `params.valid_interval` batches.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      tokenizer:
        Used to convert text to phonemes.
      optimizer_g:
        The optimizer for the generator.
      optimizer_d:
        The optimizer for the discriminator.
      scheduler_g:
        The learning rate scheduler for the generator; we call step() every epoch.
      scheduler_d:
        The learning rate scheduler for the discriminator; we call step() every epoch.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mixed precision training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of nodes in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
    """
    model.train()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # used to track the stats over iterations in one epoch
    tot_loss = MetricsTracker()

    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
        save_checkpoint(
            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            params=params,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,
            scheduler_d=scheduler_d,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=0,
        )

    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1

        batch_size = len(batch["tokens"])
        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
            batch, tokenizer, device
        )

        loss_info = MetricsTracker()
        loss_info["samples"] = batch_size

        try:
            with autocast(enabled=params.use_fp16):
                # forward discriminator
                loss_d, stats_d = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=False,
                )
            for k, v in stats_d.items():
                loss_info[k] = v * batch_size
            # update discriminator
            optimizer_d.zero_grad()
            scaler.scale(loss_d).backward()
            scaler.step(optimizer_d)

            with autocast(enabled=params.use_fp16):
                # forward generator
                loss_g, stats_g = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=True,
                    return_sample=params.batch_idx_train % params.log_interval == 0,
                )
            for k, v in stats_g.items():
                if "returned_sample" not in k:
                    loss_info[k] = v * batch_size
            # update generator
            optimizer_g.zero_grad()
            scaler.scale(loss_g).backward()
            scaler.step(optimizer_g)
            scaler.update()

            # summary stats
            tot_loss = tot_loss + loss_info
        except:  # noqa
            save_bad_model()
            raise

        if params.print_diagnostics and batch_idx == 5:
            return

        if params.batch_idx_train % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 8.0 or (
                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
            ):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                save_bad_model()
                raise RuntimeError(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )

        if params.batch_idx_train % params.log_interval == 0:
            cur_lr_g = max(scheduler_g.get_last_lr())
            cur_lr_d = max(scheduler_d.get_last_lr())
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
                f"loss[{loss_info}], tot_loss[{tot_loss}], "
                f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate_g", cur_lr_g, params.batch_idx_train
                )
                tb_writer.add_scalar(
                    "train/learning_rate_d", cur_lr_d, params.batch_idx_train
                )
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                    )
                if "returned_sample" in stats_g:
                    speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
                    tb_writer.add_audio(
                        "train/speech_hat_",
                        speech_hat_,
                        params.batch_idx_train,
                        params.sampling_rate,
                    )
                    tb_writer.add_audio(
                        "train/speech_",
                        speech_,
                        params.batch_idx_train,
                        params.sampling_rate,
                    )
                    tb_writer.add_image(
                        "train/mel_hat_",
                        plot_feature(mel_hat_),
                        params.batch_idx_train,
                        dataformats="HWC",
                    )
                    tb_writer.add_image(
                        "train/mel_",
                        plot_feature(mel_),
                        params.batch_idx_train,
                        dataformats="HWC",
                    )

        if (
            params.batch_idx_train % params.valid_interval == 0
            and not params.print_diagnostics
        ):
            logging.info("Computing validation loss")
            valid_info, (speech_hat, speech) = compute_validation_loss(
                params=params,
                model=model,
                tokenizer=tokenizer,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )
                tb_writer.add_audio(
                    "train/valid_speech_hat",
                    speech_hat,
                    params.batch_idx_train,
                    params.sampling_rate,
                )
                tb_writer.add_audio(
                    "train/valid_speech",
                    speech,
                    params.batch_idx_train,
                    params.sampling_rate,
                )

    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss

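# Note (added for clarity): each batch runs one discriminator update followed by
# one generator update, the standard GAN alternation; scaler.update() is invoked
# once per batch, after the generator step, so both backward passes share the
# same grad scale within a batch.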
def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: Tokenizer,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
    rank: int = 0,
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
    """Run the validation process."""
    model.eval()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # used to summarize the stats over iterations
    tot_loss = MetricsTracker()
    returned_sample = None

    with torch.no_grad():
        for batch_idx, batch in enumerate(valid_dl):
            batch_size = len(batch["tokens"])
            (
                audio,
                audio_lens,
                features,
                features_lens,
                tokens,
                tokens_lens,
            ) = prepare_input(batch, tokenizer, device)

            loss_info = MetricsTracker()
            loss_info["samples"] = batch_size

            # forward discriminator
            loss_d, stats_d = model(
                text=tokens,
                text_lengths=tokens_lens,
                feats=features,
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
                forward_generator=False,
            )
            assert loss_d.requires_grad is False
            for k, v in stats_d.items():
                loss_info[k] = v * batch_size

            # forward generator
            loss_g, stats_g = model(
                text=tokens,
                text_lengths=tokens_lens,
                feats=features,
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
                forward_generator=True,
            )
            assert loss_g.requires_grad is False
            for k, v in stats_g.items():
                loss_info[k] = v * batch_size

            # summary stats
            tot_loss = tot_loss + loss_info

            # infer for the first batch
            if batch_idx == 0 and rank == 0:
                inner_model = model.module if isinstance(model, DDP) else model
                audio_pred, _, duration = inner_model.inference(
                    text=tokens[0, : tokens_lens[0].item()]
                )
                audio_pred = audio_pred.data.cpu().numpy()
                audio_len_pred = (
                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
                )
                assert audio_len_pred == len(audio_pred), (
                    audio_len_pred,
                    len(audio_pred),
                )
                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
                returned_sample = (audio_pred, audio_gt)

    if world_size > 1:
        tot_loss.reduce(device)

    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss, returned_sample

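# The length assertion above holds because the VITS decoder upsamples each
# predicted frame by exactly params.frame_shift (256) audio samples, so
# duration.sum() * frame_shift equals the number of generated samples.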
def scan_pessimistic_batches_for_oom(
    model: Union[nn.Module, DDP],
    train_dl: torch.utils.data.DataLoader,
    tokenizer: Tokenizer,
    optimizer_g: torch.optim.Optimizer,
    optimizer_d: torch.optim.Optimizer,
    params: AttributeDict,
):
    from lhotse.dataset import find_pessimistic_batches

    logging.info(
        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
    )
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
            batch, tokenizer, device
        )
        try:
            # for discriminator
            with autocast(enabled=params.use_fp16):
                loss_d, stats_d = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=False,
                )
            optimizer_d.zero_grad()
            loss_d.backward()
            # for generator
            with autocast(enabled=params.use_fp16):
                loss_g, stats_g = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=True,
                )
            optimizer_g.zero_grad()
            loss_g.backward()
        except Exception as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current "
                    "max_duration setting. We recommend decreasing "
                    "max_duration and trying again.\n"
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
            raise
        logging.info(
            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
        )

def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    generator = model.generator
    discriminator = model.discriminator

    num_param_g = sum([p.numel() for p in generator.parameters()])
    logging.info(f"Number of parameters in generator: {num_param_g}")
    num_param_d = sum([p.numel() for p in discriminator.parameters()])
    logging.info(f"Number of parameters in discriminator: {num_param_d}")
    logging.info(f"Total number of parameters: {num_param_g + num_param_d}")

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(params=params, model=model)

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer_g = torch.optim.AdamW(
        generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
    )
    optimizer_d = torch.optim.AdamW(
        discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
    )

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)

    if checkpoints is not None:
        # load state_dict for optimizers
        if "optimizer_g" in checkpoints:
            logging.info("Loading optimizer_g state dict")
            optimizer_g.load_state_dict(checkpoints["optimizer_g"])
        if "optimizer_d" in checkpoints:
            logging.info("Loading optimizer_d state dict")
            optimizer_d.load_state_dict(checkpoints["optimizer_d"])

        # load state_dict for schedulers
        if "scheduler_g" in checkpoints:
            logging.info("Loading scheduler_g state dict")
            scheduler_g.load_state_dict(checkpoints["scheduler_g"])
        if "scheduler_d" in checkpoints:
            logging.info("Loading scheduler_d state dict")
            scheduler_d.load_state_dict(checkpoints["scheduler_d"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    baker_zh = BakerZhSpeechTtsDataModule(args)

    train_cuts = baker_zh.train_cuts()

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 20 seconds.
        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold.
        if c.duration < 1.0 or c.duration > 20.0:
            # logging.warning(
            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            # )
            return False
        return True

    train_cuts = train_cuts.filter(remove_short_and_long_utt)
    train_dl = baker_zh.train_dataloaders(train_cuts)

    valid_cuts = baker_zh.valid_cuts()
    valid_dl = baker_zh.valid_dataloaders(valid_cuts)

    if not params.print_diagnostics:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            tokenizer=tokenizer,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            params=params,
        )

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        logging.info(f"Start epoch {epoch}")

        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        params.cur_epoch = epoch

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        train_one_epoch(
            params=params,
            model=model,
            tokenizer=tokenizer,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,
            scheduler_d=scheduler_d,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
            filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
            save_checkpoint(
                filename=filename,
                params=params,
                model=model,
                optimizer_g=optimizer_g,
                optimizer_d=optimizer_d,
                scheduler_g=scheduler_g,
                scheduler_d=scheduler_d,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            if rank == 0:
                if params.best_train_epoch == params.cur_epoch:
                    best_train_filename = params.exp_dir / "best-train-loss.pt"
                    copyfile(src=filename, dst=best_train_filename)

                if params.best_valid_epoch == params.cur_epoch:
                    best_valid_filename = params.exp_dir / "best-valid-loss.pt"
                    copyfile(src=filename, dst=best_valid_filename)

        # step per epoch
        scheduler_g.step()
        scheduler_d.step()

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()

def main():
    parser = get_parser()
    BakerZhSpeechTtsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/transform.py
@ -1,330 +0,0 @@
# Copyright      2021  Piotr Żelasko
# Copyright 2022-2023  Xiaomi Corporation  (Authors: Mingshuang Luo,
#                                                    Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
    SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import (  # noqa F401 for AudioSamples
    AudioSamples,
    OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)

class BakerZhSpeechTtsDataModule:
    """
    DataModule for TTS experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in TTS
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in TTS tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.sampling_rate = 48000

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/spectrogram"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )

        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop the last batch. Used by the sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=False,
            help="When enabled, each batch will have the "
            "field: batch['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

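    # A minimal usage sketch (names as defined in this file; the argv values
    # are illustrative):
    #
    #   parser = argparse.ArgumentParser()
    #   BakerZhSpeechTtsDataModule.add_arguments(parser)
    #   args = parser.parse_args(["--max-duration", "100"])
    #   dm = BakerZhSpeechTtsDataModule(args)
    #   train_dl = dm.train_dataloaders(dm.train_cuts())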
    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to create train dataset")
        train = SpeechSynthesisDataset(
            return_text=False,
            return_tokens=True,
            feature_input_strategy=eval(self.args.input_strategy)(),
            return_cuts=self.args.return_cuts,
        )

        if self.args.on_the_fly_feats:
            sampling_rate = self.sampling_rate
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in seconds)
                frame_shift=256 / sampling_rate,  # (in seconds)
                use_fft_mag=True,
            )
            train = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                return_cuts=self.args.return_cuts,
            )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=self.args.num_buckets * 2000,
                shuffle_buffer_size=self.args.num_buckets * 5000,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

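    # Note (added for clarity): batch_size=None above delegates batching to the
    # Lhotse sampler, which yields whole batches whose total speech duration is
    # capped by --max-duration, so the DataLoader does no extra collation.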
    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            sampling_rate = self.sampling_rate
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in seconds)
                frame_shift=256 / sampling_rate,  # (in seconds)
                use_fft_mag=True,
            )
            validate = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            num_buckets=self.args.num_buckets,
            shuffle=False,
        )
        logging.info("About to create valid dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.info("About to create test dataset")
        if self.args.on_the_fly_feats:
            sampling_rate = self.sampling_rate
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in seconds)
                frame_shift=256 / sampling_rate,  # (in seconds)
                use_fft_mag=True,
            )
            test = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            test = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        test_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            num_buckets=self.args.num_buckets,
            shuffle=False,
        )
        logging.info("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=test_sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz"
        )

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        logging.info("About to get validation cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz"
        )

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.info("About to get test cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz"
        )
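# The three cut manifests above are expected to be produced by the
# feature-preparation stage and stored under --manifest-dir
# (default: data/spectrogram).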
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/utils.py
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/vits.py
@ -1 +0,0 @@
../../../ljspeech/TTS/vits/wavenet.py