mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 10:32:17 +00:00
remove baker-zh
This commit is contained in:
parent
f9bd5ced9d
commit
35578f0593
@ -1,7 +0,0 @@
|
|||||||
# Introduction
|
|
||||||
|
|
||||||
[./symbols.py](./symbols.py) is copied from
|
|
||||||
https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py
|
|
||||||
|
|
||||||
[./pypinyin-local.dict](./pypinyin-local.dict) is copied from
|
|
||||||
https://github.com/UEhQZXI/vits_chinese/blob/master/misc/pypinyin-local.dict
|
|
@ -1,106 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
|
||||||
# Zengwei Yao)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
This file computes spectrogram features of the baker_zh dataset.
|
|
||||||
It looks for manifests in the directory data/manifests.
|
|
||||||
|
|
||||||
The generated spectrogram features are saved in data/spectrogram.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from lhotse import (
|
|
||||||
CutSet,
|
|
||||||
LilcomChunkyWriter,
|
|
||||||
Spectrogram,
|
|
||||||
SpectrogramConfig,
|
|
||||||
load_manifest,
|
|
||||||
)
|
|
||||||
from lhotse.audio import RecordingSet
|
|
||||||
from lhotse.supervision import SupervisionSet
|
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
|
||||||
# it wastes a lot of CPU and slows things down.
|
|
||||||
# Do this outside of main() in case it needs to take effect
|
|
||||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
|
||||||
torch.set_num_threads(1)
|
|
||||||
torch.set_num_interop_threads(1)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_spectrogram_baker_zh():
    """Compute and store spectrogram features for the baker_zh dataset.

    Reads ``baker_zh_recordings_all.jsonl.gz`` and
    ``baker_zh_supervisions_all.jsonl.gz`` from ``data/manifests``, extracts
    FFT-magnitude spectrograms (1024-sample window, 256-sample shift at
    48 kHz) and writes the features plus ``baker_zh_cuts_all.jsonl.gz`` to
    ``data/spectrogram``.

    If the output cuts manifest already exists the function returns
    immediately without loading anything.
    """
    src_dir = Path("data/manifests")
    output_dir = Path("data/spectrogram")
    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to a single job instead of failing inside min().
    num_jobs = min(4, os.cpu_count() or 1)

    sampling_rate = 48000
    frame_length = 1024 / sampling_rate  # (in second)
    frame_shift = 256 / sampling_rate  # (in second)
    use_fft_mag = True

    prefix = "baker_zh"
    suffix = "jsonl.gz"
    partition = "all"

    # Check for existing output first so that a re-run neither loads the
    # manifests nor starts an executor just to discover there is nothing to do.
    cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
    if (output_dir / cuts_filename).is_file():
        logging.info(f"{cuts_filename} already exists - skipping.")
        return

    recordings = load_manifest(
        src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
    )
    supervisions = load_manifest(
        src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
    )

    config = SpectrogramConfig(
        sampling_rate=sampling_rate,
        frame_length=frame_length,
        frame_shift=frame_shift,
        use_fft_mag=use_fft_mag,
    )
    extractor = Spectrogram(config)

    with get_executor() as ex:  # Initialize the executor only once.
        logging.info(f"Processing {partition}")
        cut_set = CutSet.from_manifests(
            recordings=recordings, supervisions=supervisions
        )

        cut_set = cut_set.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{output_dir}/{prefix}_feats_{partition}",
            # when an executor is specified, make more partitions
            num_jobs=num_jobs if ex is None else 80,
            executor=ex,
            storage_type=LilcomChunkyWriter,
        )
        cut_set.to_file(output_dir / cuts_filename)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Log at INFO level with timestamp, level and source location, then run
    # the feature extraction.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    compute_spectrogram_baker_zh()
|
|
@ -1,421 +0,0 @@
|
|||||||
# This dict is copied from
|
|
||||||
# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py
|
|
||||||
pinyin_dict = {
|
|
||||||
"a": ("^", "a"),
|
|
||||||
"ai": ("^", "ai"),
|
|
||||||
"an": ("^", "an"),
|
|
||||||
"ang": ("^", "ang"),
|
|
||||||
"ao": ("^", "ao"),
|
|
||||||
"ba": ("b", "a"),
|
|
||||||
"bai": ("b", "ai"),
|
|
||||||
"ban": ("b", "an"),
|
|
||||||
"bang": ("b", "ang"),
|
|
||||||
"bao": ("b", "ao"),
|
|
||||||
"be": ("b", "e"),
|
|
||||||
"bei": ("b", "ei"),
|
|
||||||
"ben": ("b", "en"),
|
|
||||||
"beng": ("b", "eng"),
|
|
||||||
"bi": ("b", "i"),
|
|
||||||
"bian": ("b", "ian"),
|
|
||||||
"biao": ("b", "iao"),
|
|
||||||
"bie": ("b", "ie"),
|
|
||||||
"bin": ("b", "in"),
|
|
||||||
"bing": ("b", "ing"),
|
|
||||||
"bo": ("b", "o"),
|
|
||||||
"bu": ("b", "u"),
|
|
||||||
"ca": ("c", "a"),
|
|
||||||
"cai": ("c", "ai"),
|
|
||||||
"can": ("c", "an"),
|
|
||||||
"cang": ("c", "ang"),
|
|
||||||
"cao": ("c", "ao"),
|
|
||||||
"ce": ("c", "e"),
|
|
||||||
"cen": ("c", "en"),
|
|
||||||
"ceng": ("c", "eng"),
|
|
||||||
"cha": ("ch", "a"),
|
|
||||||
"chai": ("ch", "ai"),
|
|
||||||
"chan": ("ch", "an"),
|
|
||||||
"chang": ("ch", "ang"),
|
|
||||||
"chao": ("ch", "ao"),
|
|
||||||
"che": ("ch", "e"),
|
|
||||||
"chen": ("ch", "en"),
|
|
||||||
"cheng": ("ch", "eng"),
|
|
||||||
"chi": ("ch", "iii"),
|
|
||||||
"chong": ("ch", "ong"),
|
|
||||||
"chou": ("ch", "ou"),
|
|
||||||
"chu": ("ch", "u"),
|
|
||||||
"chua": ("ch", "ua"),
|
|
||||||
"chuai": ("ch", "uai"),
|
|
||||||
"chuan": ("ch", "uan"),
|
|
||||||
"chuang": ("ch", "uang"),
|
|
||||||
"chui": ("ch", "uei"),
|
|
||||||
"chun": ("ch", "uen"),
|
|
||||||
"chuo": ("ch", "uo"),
|
|
||||||
"ci": ("c", "ii"),
|
|
||||||
"cong": ("c", "ong"),
|
|
||||||
"cou": ("c", "ou"),
|
|
||||||
"cu": ("c", "u"),
|
|
||||||
"cuan": ("c", "uan"),
|
|
||||||
"cui": ("c", "uei"),
|
|
||||||
"cun": ("c", "uen"),
|
|
||||||
"cuo": ("c", "uo"),
|
|
||||||
"da": ("d", "a"),
|
|
||||||
"dai": ("d", "ai"),
|
|
||||||
"dan": ("d", "an"),
|
|
||||||
"dang": ("d", "ang"),
|
|
||||||
"dao": ("d", "ao"),
|
|
||||||
"de": ("d", "e"),
|
|
||||||
"dei": ("d", "ei"),
|
|
||||||
"den": ("d", "en"),
|
|
||||||
"deng": ("d", "eng"),
|
|
||||||
"di": ("d", "i"),
|
|
||||||
"dia": ("d", "ia"),
|
|
||||||
"dian": ("d", "ian"),
|
|
||||||
"diao": ("d", "iao"),
|
|
||||||
"die": ("d", "ie"),
|
|
||||||
"ding": ("d", "ing"),
|
|
||||||
"diu": ("d", "iou"),
|
|
||||||
"dong": ("d", "ong"),
|
|
||||||
"dou": ("d", "ou"),
|
|
||||||
"du": ("d", "u"),
|
|
||||||
"duan": ("d", "uan"),
|
|
||||||
"dui": ("d", "uei"),
|
|
||||||
"dun": ("d", "uen"),
|
|
||||||
"duo": ("d", "uo"),
|
|
||||||
"e": ("^", "e"),
|
|
||||||
"ei": ("^", "ei"),
|
|
||||||
"en": ("^", "en"),
|
|
||||||
"ng": ("^", "en"),
|
|
||||||
"eng": ("^", "eng"),
|
|
||||||
"er": ("^", "er"),
|
|
||||||
"fa": ("f", "a"),
|
|
||||||
"fan": ("f", "an"),
|
|
||||||
"fang": ("f", "ang"),
|
|
||||||
"fei": ("f", "ei"),
|
|
||||||
"fen": ("f", "en"),
|
|
||||||
"feng": ("f", "eng"),
|
|
||||||
"fo": ("f", "o"),
|
|
||||||
"fou": ("f", "ou"),
|
|
||||||
"fu": ("f", "u"),
|
|
||||||
"ga": ("g", "a"),
|
|
||||||
"gai": ("g", "ai"),
|
|
||||||
"gan": ("g", "an"),
|
|
||||||
"gang": ("g", "ang"),
|
|
||||||
"gao": ("g", "ao"),
|
|
||||||
"ge": ("g", "e"),
|
|
||||||
"gei": ("g", "ei"),
|
|
||||||
"gen": ("g", "en"),
|
|
||||||
"geng": ("g", "eng"),
|
|
||||||
"gong": ("g", "ong"),
|
|
||||||
"gou": ("g", "ou"),
|
|
||||||
"gu": ("g", "u"),
|
|
||||||
"gua": ("g", "ua"),
|
|
||||||
"guai": ("g", "uai"),
|
|
||||||
"guan": ("g", "uan"),
|
|
||||||
"guang": ("g", "uang"),
|
|
||||||
"gui": ("g", "uei"),
|
|
||||||
"gun": ("g", "uen"),
|
|
||||||
"guo": ("g", "uo"),
|
|
||||||
"ha": ("h", "a"),
|
|
||||||
"hai": ("h", "ai"),
|
|
||||||
"han": ("h", "an"),
|
|
||||||
"hang": ("h", "ang"),
|
|
||||||
"hao": ("h", "ao"),
|
|
||||||
"he": ("h", "e"),
|
|
||||||
"hei": ("h", "ei"),
|
|
||||||
"hen": ("h", "en"),
|
|
||||||
"heng": ("h", "eng"),
|
|
||||||
"hong": ("h", "ong"),
|
|
||||||
"hou": ("h", "ou"),
|
|
||||||
"hu": ("h", "u"),
|
|
||||||
"hua": ("h", "ua"),
|
|
||||||
"huai": ("h", "uai"),
|
|
||||||
"huan": ("h", "uan"),
|
|
||||||
"huang": ("h", "uang"),
|
|
||||||
"hui": ("h", "uei"),
|
|
||||||
"hun": ("h", "uen"),
|
|
||||||
"huo": ("h", "uo"),
|
|
||||||
"ji": ("j", "i"),
|
|
||||||
"jia": ("j", "ia"),
|
|
||||||
"jian": ("j", "ian"),
|
|
||||||
"jiang": ("j", "iang"),
|
|
||||||
"jiao": ("j", "iao"),
|
|
||||||
"jie": ("j", "ie"),
|
|
||||||
"jin": ("j", "in"),
|
|
||||||
"jing": ("j", "ing"),
|
|
||||||
"jiong": ("j", "iong"),
|
|
||||||
"jiu": ("j", "iou"),
|
|
||||||
"ju": ("j", "v"),
|
|
||||||
"juan": ("j", "van"),
|
|
||||||
"jue": ("j", "ve"),
|
|
||||||
"jun": ("j", "vn"),
|
|
||||||
"ka": ("k", "a"),
|
|
||||||
"kai": ("k", "ai"),
|
|
||||||
"kan": ("k", "an"),
|
|
||||||
"kang": ("k", "ang"),
|
|
||||||
"kao": ("k", "ao"),
|
|
||||||
"ke": ("k", "e"),
|
|
||||||
"kei": ("k", "ei"),
|
|
||||||
"ken": ("k", "en"),
|
|
||||||
"keng": ("k", "eng"),
|
|
||||||
"kong": ("k", "ong"),
|
|
||||||
"kou": ("k", "ou"),
|
|
||||||
"ku": ("k", "u"),
|
|
||||||
"kua": ("k", "ua"),
|
|
||||||
"kuai": ("k", "uai"),
|
|
||||||
"kuan": ("k", "uan"),
|
|
||||||
"kuang": ("k", "uang"),
|
|
||||||
"kui": ("k", "uei"),
|
|
||||||
"kun": ("k", "uen"),
|
|
||||||
"kuo": ("k", "uo"),
|
|
||||||
"la": ("l", "a"),
|
|
||||||
"lai": ("l", "ai"),
|
|
||||||
"lan": ("l", "an"),
|
|
||||||
"lang": ("l", "ang"),
|
|
||||||
"lao": ("l", "ao"),
|
|
||||||
"le": ("l", "e"),
|
|
||||||
"lei": ("l", "ei"),
|
|
||||||
"leng": ("l", "eng"),
|
|
||||||
"li": ("l", "i"),
|
|
||||||
"lia": ("l", "ia"),
|
|
||||||
"lian": ("l", "ian"),
|
|
||||||
"liang": ("l", "iang"),
|
|
||||||
"liao": ("l", "iao"),
|
|
||||||
"lie": ("l", "ie"),
|
|
||||||
"lin": ("l", "in"),
|
|
||||||
"ling": ("l", "ing"),
|
|
||||||
"liu": ("l", "iou"),
|
|
||||||
"lo": ("l", "o"),
|
|
||||||
"long": ("l", "ong"),
|
|
||||||
"lou": ("l", "ou"),
|
|
||||||
"lu": ("l", "u"),
|
|
||||||
"lv": ("l", "v"),
|
|
||||||
"luan": ("l", "uan"),
|
|
||||||
"lve": ("l", "ve"),
|
|
||||||
"lue": ("l", "ve"),
|
|
||||||
"lun": ("l", "uen"),
|
|
||||||
"luo": ("l", "uo"),
|
|
||||||
"ma": ("m", "a"),
|
|
||||||
"mai": ("m", "ai"),
|
|
||||||
"man": ("m", "an"),
|
|
||||||
"mang": ("m", "ang"),
|
|
||||||
"mao": ("m", "ao"),
|
|
||||||
"me": ("m", "e"),
|
|
||||||
"mei": ("m", "ei"),
|
|
||||||
"men": ("m", "en"),
|
|
||||||
"meng": ("m", "eng"),
|
|
||||||
"mi": ("m", "i"),
|
|
||||||
"mian": ("m", "ian"),
|
|
||||||
"miao": ("m", "iao"),
|
|
||||||
"mie": ("m", "ie"),
|
|
||||||
"min": ("m", "in"),
|
|
||||||
"ming": ("m", "ing"),
|
|
||||||
"miu": ("m", "iou"),
|
|
||||||
"mo": ("m", "o"),
|
|
||||||
"mou": ("m", "ou"),
|
|
||||||
"mu": ("m", "u"),
|
|
||||||
"na": ("n", "a"),
|
|
||||||
"nai": ("n", "ai"),
|
|
||||||
"nan": ("n", "an"),
|
|
||||||
"nang": ("n", "ang"),
|
|
||||||
"nao": ("n", "ao"),
|
|
||||||
"ne": ("n", "e"),
|
|
||||||
"nei": ("n", "ei"),
|
|
||||||
"nen": ("n", "en"),
|
|
||||||
"neng": ("n", "eng"),
|
|
||||||
"ni": ("n", "i"),
|
|
||||||
"nia": ("n", "ia"),
|
|
||||||
"nian": ("n", "ian"),
|
|
||||||
"niang": ("n", "iang"),
|
|
||||||
"niao": ("n", "iao"),
|
|
||||||
"nie": ("n", "ie"),
|
|
||||||
"nin": ("n", "in"),
|
|
||||||
"ning": ("n", "ing"),
|
|
||||||
"niu": ("n", "iou"),
|
|
||||||
"nong": ("n", "ong"),
|
|
||||||
"nou": ("n", "ou"),
|
|
||||||
"nu": ("n", "u"),
|
|
||||||
"nv": ("n", "v"),
|
|
||||||
"nuan": ("n", "uan"),
|
|
||||||
"nve": ("n", "ve"),
|
|
||||||
"nue": ("n", "ve"),
|
|
||||||
"nuo": ("n", "uo"),
|
|
||||||
"o": ("^", "o"),
|
|
||||||
"ou": ("^", "ou"),
|
|
||||||
"pa": ("p", "a"),
|
|
||||||
"pai": ("p", "ai"),
|
|
||||||
"pan": ("p", "an"),
|
|
||||||
"pang": ("p", "ang"),
|
|
||||||
"pao": ("p", "ao"),
|
|
||||||
"pe": ("p", "e"),
|
|
||||||
"pei": ("p", "ei"),
|
|
||||||
"pen": ("p", "en"),
|
|
||||||
"peng": ("p", "eng"),
|
|
||||||
"pi": ("p", "i"),
|
|
||||||
"pian": ("p", "ian"),
|
|
||||||
"piao": ("p", "iao"),
|
|
||||||
"pie": ("p", "ie"),
|
|
||||||
"pin": ("p", "in"),
|
|
||||||
"ping": ("p", "ing"),
|
|
||||||
"po": ("p", "o"),
|
|
||||||
"pou": ("p", "ou"),
|
|
||||||
"pu": ("p", "u"),
|
|
||||||
"qi": ("q", "i"),
|
|
||||||
"qia": ("q", "ia"),
|
|
||||||
"qian": ("q", "ian"),
|
|
||||||
"qiang": ("q", "iang"),
|
|
||||||
"qiao": ("q", "iao"),
|
|
||||||
"qie": ("q", "ie"),
|
|
||||||
"qin": ("q", "in"),
|
|
||||||
"qing": ("q", "ing"),
|
|
||||||
"qiong": ("q", "iong"),
|
|
||||||
"qiu": ("q", "iou"),
|
|
||||||
"qu": ("q", "v"),
|
|
||||||
"quan": ("q", "van"),
|
|
||||||
"que": ("q", "ve"),
|
|
||||||
"qun": ("q", "vn"),
|
|
||||||
"ran": ("r", "an"),
|
|
||||||
"rang": ("r", "ang"),
|
|
||||||
"rao": ("r", "ao"),
|
|
||||||
"re": ("r", "e"),
|
|
||||||
"ren": ("r", "en"),
|
|
||||||
"reng": ("r", "eng"),
|
|
||||||
"ri": ("r", "iii"),
|
|
||||||
"rong": ("r", "ong"),
|
|
||||||
"rou": ("r", "ou"),
|
|
||||||
"ru": ("r", "u"),
|
|
||||||
"rua": ("r", "ua"),
|
|
||||||
"ruan": ("r", "uan"),
|
|
||||||
"rui": ("r", "uei"),
|
|
||||||
"run": ("r", "uen"),
|
|
||||||
"ruo": ("r", "uo"),
|
|
||||||
"sa": ("s", "a"),
|
|
||||||
"sai": ("s", "ai"),
|
|
||||||
"san": ("s", "an"),
|
|
||||||
"sang": ("s", "ang"),
|
|
||||||
"sao": ("s", "ao"),
|
|
||||||
"se": ("s", "e"),
|
|
||||||
"sen": ("s", "en"),
|
|
||||||
"seng": ("s", "eng"),
|
|
||||||
"sha": ("sh", "a"),
|
|
||||||
"shai": ("sh", "ai"),
|
|
||||||
"shan": ("sh", "an"),
|
|
||||||
"shang": ("sh", "ang"),
|
|
||||||
"shao": ("sh", "ao"),
|
|
||||||
"she": ("sh", "e"),
|
|
||||||
"shei": ("sh", "ei"),
|
|
||||||
"shen": ("sh", "en"),
|
|
||||||
"sheng": ("sh", "eng"),
|
|
||||||
"shi": ("sh", "iii"),
|
|
||||||
"shou": ("sh", "ou"),
|
|
||||||
"shu": ("sh", "u"),
|
|
||||||
"shua": ("sh", "ua"),
|
|
||||||
"shuai": ("sh", "uai"),
|
|
||||||
"shuan": ("sh", "uan"),
|
|
||||||
"shuang": ("sh", "uang"),
|
|
||||||
"shui": ("sh", "uei"),
|
|
||||||
"shun": ("sh", "uen"),
|
|
||||||
"shuo": ("sh", "uo"),
|
|
||||||
"si": ("s", "ii"),
|
|
||||||
"song": ("s", "ong"),
|
|
||||||
"sou": ("s", "ou"),
|
|
||||||
"su": ("s", "u"),
|
|
||||||
"suan": ("s", "uan"),
|
|
||||||
"sui": ("s", "uei"),
|
|
||||||
"sun": ("s", "uen"),
|
|
||||||
"suo": ("s", "uo"),
|
|
||||||
"ta": ("t", "a"),
|
|
||||||
"tai": ("t", "ai"),
|
|
||||||
"tan": ("t", "an"),
|
|
||||||
"tang": ("t", "ang"),
|
|
||||||
"tao": ("t", "ao"),
|
|
||||||
"te": ("t", "e"),
|
|
||||||
"tei": ("t", "ei"),
|
|
||||||
"teng": ("t", "eng"),
|
|
||||||
"ti": ("t", "i"),
|
|
||||||
"tian": ("t", "ian"),
|
|
||||||
"tiao": ("t", "iao"),
|
|
||||||
"tie": ("t", "ie"),
|
|
||||||
"ting": ("t", "ing"),
|
|
||||||
"tong": ("t", "ong"),
|
|
||||||
"tou": ("t", "ou"),
|
|
||||||
"tu": ("t", "u"),
|
|
||||||
"tuan": ("t", "uan"),
|
|
||||||
"tui": ("t", "uei"),
|
|
||||||
"tun": ("t", "uen"),
|
|
||||||
"tuo": ("t", "uo"),
|
|
||||||
"wa": ("^", "ua"),
|
|
||||||
"wai": ("^", "uai"),
|
|
||||||
"wan": ("^", "uan"),
|
|
||||||
"wang": ("^", "uang"),
|
|
||||||
"wei": ("^", "uei"),
|
|
||||||
"wen": ("^", "uen"),
|
|
||||||
"weng": ("^", "ueng"),
|
|
||||||
"wo": ("^", "uo"),
|
|
||||||
"wu": ("^", "u"),
|
|
||||||
"xi": ("x", "i"),
|
|
||||||
"xia": ("x", "ia"),
|
|
||||||
"xian": ("x", "ian"),
|
|
||||||
"xiang": ("x", "iang"),
|
|
||||||
"xiao": ("x", "iao"),
|
|
||||||
"xie": ("x", "ie"),
|
|
||||||
"xin": ("x", "in"),
|
|
||||||
"xing": ("x", "ing"),
|
|
||||||
"xiong": ("x", "iong"),
|
|
||||||
"xiu": ("x", "iou"),
|
|
||||||
"xu": ("x", "v"),
|
|
||||||
"xuan": ("x", "van"),
|
|
||||||
"xue": ("x", "ve"),
|
|
||||||
"xun": ("x", "vn"),
|
|
||||||
"ya": ("^", "ia"),
|
|
||||||
"yan": ("^", "ian"),
|
|
||||||
"yang": ("^", "iang"),
|
|
||||||
"yao": ("^", "iao"),
|
|
||||||
"ye": ("^", "ie"),
|
|
||||||
"yi": ("^", "i"),
|
|
||||||
"yin": ("^", "in"),
|
|
||||||
"ying": ("^", "ing"),
|
|
||||||
"yo": ("^", "iou"),
|
|
||||||
"yong": ("^", "iong"),
|
|
||||||
"you": ("^", "iou"),
|
|
||||||
"yu": ("^", "v"),
|
|
||||||
"yuan": ("^", "van"),
|
|
||||||
"yue": ("^", "ve"),
|
|
||||||
"yun": ("^", "vn"),
|
|
||||||
"za": ("z", "a"),
|
|
||||||
"zai": ("z", "ai"),
|
|
||||||
"zan": ("z", "an"),
|
|
||||||
"zang": ("z", "ang"),
|
|
||||||
"zao": ("z", "ao"),
|
|
||||||
"ze": ("z", "e"),
|
|
||||||
"zei": ("z", "ei"),
|
|
||||||
"zen": ("z", "en"),
|
|
||||||
"zeng": ("z", "eng"),
|
|
||||||
"zha": ("zh", "a"),
|
|
||||||
"zhai": ("zh", "ai"),
|
|
||||||
"zhan": ("zh", "an"),
|
|
||||||
"zhang": ("zh", "ang"),
|
|
||||||
"zhao": ("zh", "ao"),
|
|
||||||
"zhe": ("zh", "e"),
|
|
||||||
"zhei": ("zh", "ei"),
|
|
||||||
"zhen": ("zh", "en"),
|
|
||||||
"zheng": ("zh", "eng"),
|
|
||||||
"zhi": ("zh", "iii"),
|
|
||||||
"zhong": ("zh", "ong"),
|
|
||||||
"zhou": ("zh", "ou"),
|
|
||||||
"zhu": ("zh", "u"),
|
|
||||||
"zhua": ("zh", "ua"),
|
|
||||||
"zhuai": ("zh", "uai"),
|
|
||||||
"zhuan": ("zh", "uan"),
|
|
||||||
"zhuang": ("zh", "uang"),
|
|
||||||
"zhui": ("zh", "uei"),
|
|
||||||
"zhun": ("zh", "uen"),
|
|
||||||
"zhuo": ("zh", "uo"),
|
|
||||||
"zi": ("z", "ii"),
|
|
||||||
"zong": ("z", "ong"),
|
|
||||||
"zou": ("z", "ou"),
|
|
||||||
"zu": ("z", "u"),
|
|
||||||
"zuan": ("z", "uan"),
|
|
||||||
"zui": ("z", "uei"),
|
|
||||||
"zun": ("z", "uen"),
|
|
||||||
"zuo": ("z", "uo"),
|
|
||||||
}
|
|
@ -1,53 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
This file generates the file that maps tokens to IDs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict
|
|
||||||
from symbols import symbols
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
    """Parse the command-line arguments of this script.

    Returns:
      The parsed ``argparse.Namespace``. Its ``tokens`` attribute is a
      ``Path`` to the token table (default: ``data/tokens.txt``).
    """
    arg_parser = argparse.ArgumentParser()

    tokens_option = dict(
        type=Path,
        default=Path("data/tokens.txt"),
        help="Path to the dict that maps the text tokens to IDs",
    )
    arg_parser.add_argument("--tokens", **tokens_option)

    parsed = arg_parser.parse_args()
    return parsed
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Write the token table: one ``<token> <id>`` line per entry of ``symbols``.

    IDs are the positions of the tokens in ``symbols``, starting at 0. The
    output path comes from the ``--tokens`` command-line option.
    """
    out_path = Path(get_args().tokens)

    # Lazily rendered lines; nothing is evaluated until the file is open.
    table_lines = (f"{sym} {idx}\n" for idx, sym in enumerate(symbols))

    with open(out_path, "w", encoding="utf-8") as fout:
        fout.writelines(table_lines)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -1,59 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
This file reads the texts in the given manifest and saves the new cuts with tokens.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from lhotse import CutSet, load_manifest
|
|
||||||
|
|
||||||
from tokenizer import Tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_tokens_baker_zh():
    """Attach tokens to every baker_zh cut and save the result as a new manifest.

    Loads ``data/spectrogram/baker_zh_cuts_all.jsonl.gz``, converts the
    ``normalized_text`` of each cut's single supervision to tokens with
    ``Tokenizer.text_to_tokens``, stores them on ``cut.tokens`` and writes
    ``baker_zh_cuts_with_tokens_all.jsonl.gz`` to the same directory.

    Raises:
      AssertionError: if a cut does not contain exactly one supervision.
    """
    output_dir = Path("data/spectrogram")
    prefix = "baker_zh"
    suffix = "jsonl.gz"
    partition = "all"

    cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")

    tokenizer = Tokenizer()

    def _with_tokens(cut):
        # Each cut only contains one supervision
        assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
        text = cut.supervisions[0].normalized_text
        cut.tokens = tokenizer.text_to_tokens(text)
        return cut

    new_cut_set = CutSet.from_cuts(_with_tokens(cut) for cut in cut_set)
    new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Log at INFO level with timestamp, level and source location, then run
    # the token preparation.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    prepare_tokens_baker_zh()
|
|
@ -1,328 +0,0 @@
|
|||||||
姐姐 jie3 jie
|
|
||||||
宝宝 bao3 bao
|
|
||||||
哥哥 ge1 ge
|
|
||||||
妹妹 mei4 mei
|
|
||||||
弟弟 di4 di
|
|
||||||
妈妈 ma1 ma
|
|
||||||
开心哦 kai1 xin1 o
|
|
||||||
爸爸 ba4 ba
|
|
||||||
秘密哟 mi4 mi4 yo
|
|
||||||
哦 o
|
|
||||||
一年 yi4 nian2
|
|
||||||
一夜 yi2 ye4
|
|
||||||
一切 yi2 qie4
|
|
||||||
一座 yi2 zuo4
|
|
||||||
一下 yi2 xia4
|
|
||||||
上一山 shang4 yi2 shan1
|
|
||||||
下一山 xia4 yi2 shan1
|
|
||||||
休息 xiu1 xi2
|
|
||||||
东西 dong1 xi
|
|
||||||
上一届 shang4 yi2 jie4
|
|
||||||
便宜 pian2 yi4
|
|
||||||
加长 jia1 chang2
|
|
||||||
单田芳 shan4 tian2 fang1
|
|
||||||
帧 zhen1
|
|
||||||
长时间 chang2 shi2 jian1
|
|
||||||
长时 chang2 shi2
|
|
||||||
识别 shi2 bie2
|
|
||||||
生命中 sheng1 ming4 zhong1
|
|
||||||
踏实 ta1 shi
|
|
||||||
嗯 en4
|
|
||||||
溜达 liu1 da
|
|
||||||
少儿 shao4 er2
|
|
||||||
爷爷 ye2 ye
|
|
||||||
不是 bu2 shi4
|
|
||||||
一圈 yi1 quan1
|
|
||||||
厜读一声 zui1 du2 yi4 sheng1
|
|
||||||
一种 yi4 zhong3
|
|
||||||
一簇簇 yi2 cu4 cu4
|
|
||||||
一个 yi2 ge4
|
|
||||||
一样 yi2 yang4
|
|
||||||
一跩一跩 yi4 zhuai3 yi4 zhuai3
|
|
||||||
一会儿 yi2 hui4 er
|
|
||||||
一幢 yi2 zhuang4
|
|
||||||
挨了 ai2 le
|
|
||||||
熬菜 ao1 cai4
|
|
||||||
扒鸡 pa2 ji1
|
|
||||||
背枪 bei1 qiang1
|
|
||||||
绷瓷儿 beng4 ci2 er2
|
|
||||||
绷劲儿 beng3 jin4 er
|
|
||||||
绷着脸 beng3 zhe lian3
|
|
||||||
藏医 zang4 yi1
|
|
||||||
噌吰 cheng1 hong2
|
|
||||||
差点儿 cha4 dian3 er
|
|
||||||
差失 cha1 shi1
|
|
||||||
差误 cha1 wu4
|
|
||||||
孱头 can4 tou
|
|
||||||
乘间 cheng2 jian4
|
|
||||||
锄镰棘矜 chu2 lian2 ji2 qin2
|
|
||||||
川藏 chuan1 zang4
|
|
||||||
穿著 chuan1 zhuo2
|
|
||||||
答讪 da1 shan4
|
|
||||||
答言 da1 yan2
|
|
||||||
大伯子 da4 bai3 zi
|
|
||||||
大夫 dai4 fu
|
|
||||||
弹冠 tan2 guan1
|
|
||||||
当间 dang1 jian4
|
|
||||||
当然咯 dang1 ran2 lo
|
|
||||||
点种 dian3 zhong3
|
|
||||||
垛好 duo4 hao3
|
|
||||||
发疟子 fa1 yao4 zi
|
|
||||||
饭熟了 fan4 shou2 le
|
|
||||||
附著 fu4 zhuo2
|
|
||||||
复沓 fu4 ta4
|
|
||||||
供稿 gong1 gao3
|
|
||||||
供养 gong1 yang3
|
|
||||||
骨朵 gu1 duo
|
|
||||||
骨碌 gu1 lu
|
|
||||||
果脯 guo3 fu3
|
|
||||||
哈什玛 ha4 shi2 ma3
|
|
||||||
海蜇 hai3 zhe2
|
|
||||||
呵欠 he1 qian
|
|
||||||
河水汤汤 he2 shui3 shang1 shang1
|
|
||||||
鹄立 hu2 li4
|
|
||||||
鹄望 hu2 wang4
|
|
||||||
混人 hun2 ren2
|
|
||||||
混水 hun2 shui3
|
|
||||||
鸡血 ji1 xie3
|
|
||||||
缉鞋口 qi1 xie2 kou3
|
|
||||||
亟来闻讯 qi4 lai2 wen2 xun4
|
|
||||||
计量 ji4 liang2
|
|
||||||
济水 ji3 shui3
|
|
||||||
间杂 jian4 za2
|
|
||||||
脚跐两只船 jiao3 ci3 liang3 zhi1 chuan2
|
|
||||||
脚儿 jue2 er2
|
|
||||||
口角 kou3 jiao3
|
|
||||||
勒石 le4 shi2
|
|
||||||
累进 lei3 jin4
|
|
||||||
累累如丧家之犬 lei2 lei2 ru2 sang4 jia1 zhi1 quan3
|
|
||||||
累年 lei3 nian2
|
|
||||||
脸涨通红 lian3 zhang4 tong1 hong2
|
|
||||||
踉锵 liang4 qiang1
|
|
||||||
燎眉毛 liao3 mei2 mao2
|
|
||||||
燎头发 liao3 tou2 fa4
|
|
||||||
溜达 liu1 da
|
|
||||||
溜缝儿 liu4 feng4 er
|
|
||||||
馏口饭 liu4 kou3 fan4
|
|
||||||
遛马 liu4 ma3
|
|
||||||
遛鸟 liu4 niao3
|
|
||||||
遛弯儿 liu4 wan1 er
|
|
||||||
楼枪机 lou1 qiang1 ji1
|
|
||||||
搂钱 lou1 qian2
|
|
||||||
鹿脯 lu4 fu3
|
|
||||||
露头 lou4 tou2
|
|
||||||
落魄 luo4 po4
|
|
||||||
捋胡子 lv3 hu2 zi
|
|
||||||
绿地 lv4 di4
|
|
||||||
麦垛 mai4 duo4
|
|
||||||
没劲儿 mei2 jin4 er
|
|
||||||
闷棍 men4 gun4
|
|
||||||
闷葫芦 men4 hu2 lu
|
|
||||||
闷头干 men1 tou2 gan4
|
|
||||||
蒙古 meng3 gu3
|
|
||||||
靡日不思 mi3 ri4 bu4 si1
|
|
||||||
缪姓 miao4 xing4
|
|
||||||
抹墙 mo4 qiang2
|
|
||||||
抹下脸 ma1 xia4 lian3
|
|
||||||
泥子 ni4 zi
|
|
||||||
拗不过 niu4 bu guo4
|
|
||||||
排车 pai3 che1
|
|
||||||
盘诘 pan2 jie2
|
|
||||||
膀肿 pang1 zhong3
|
|
||||||
炮干 bao1 gan1
|
|
||||||
炮格 pao2 ge2
|
|
||||||
碰钉子 peng4 ding1 zi
|
|
||||||
缥色 piao3 se4
|
|
||||||
瀑河 bao4 he2
|
|
||||||
蹊径 xi1 jing4
|
|
||||||
前后相属 qian2 hou4 xiang1 zhu3
|
|
||||||
翘尾巴 qiao4 wei3 ba
|
|
||||||
趄坡儿 qie4 po1 er
|
|
||||||
秦桧 qin2 hui4
|
|
||||||
圈马 juan1 ma3
|
|
||||||
雀盲眼 qiao3 mang2 yan3
|
|
||||||
雀子 qiao1 zi
|
|
||||||
三年五载 san1 nian2 wu3 zai3
|
|
||||||
加载 jia1 zai3
|
|
||||||
山大王 shan1 dai4 wang
|
|
||||||
苫屋草 shan4 wu1 cao3
|
|
||||||
数数 shu3 shu4
|
|
||||||
说客 shui4 ke4
|
|
||||||
思量 si1 liang2
|
|
||||||
伺侯 ci4 hou
|
|
||||||
踏实 ta1 shi
|
|
||||||
提溜 di1 liu
|
|
||||||
调拨 diao4 bo1
|
|
||||||
帖子 tie3 zi
|
|
||||||
铜钿 tong2 tian2
|
|
||||||
头昏脑涨 tou2 hun1 nao3 zhang4
|
|
||||||
褪色 tui4 se4
|
|
||||||
褪着手 tun4 zhe shou3
|
|
||||||
圩子 wei2 zi
|
|
||||||
尾巴 wei3 ba
|
|
||||||
系好船只 xi4 hao3 chuan2 zhi1
|
|
||||||
系好马匹 xi4 hao3 ma3 pi3
|
|
||||||
杏脯 xing4 fu3
|
|
||||||
姓单 xing4 shan4
|
|
||||||
姓葛 xing4 ge3
|
|
||||||
姓哈 xing4 ha3
|
|
||||||
姓解 xing4 xie4
|
|
||||||
姓秘 xing4 bi4
|
|
||||||
姓宁 xing4 ning4
|
|
||||||
旋风 xuan4 feng1
|
|
||||||
旋根车轴 xuan4 gen1 che1 zhou2
|
|
||||||
荨麻 qian2 ma2
|
|
||||||
一幢楼房 yi1 zhuang4 lou2 fang2
|
|
||||||
遗之千金 wei4 zhi1 qian1 jin1
|
|
||||||
殷殷 yin3 yin3
|
|
||||||
应招 ying4 zhao1
|
|
||||||
用称约 yong4 cheng4 yao1
|
|
||||||
约斤肉 yao1 jin1 rou4
|
|
||||||
晕机 yun4 ji1
|
|
||||||
熨贴 yu4 tie1
|
|
||||||
咋办 za3 ban4
|
|
||||||
咋呼 zha1 hu
|
|
||||||
仔兽 zi3 shou4
|
|
||||||
扎彩 za1 cai3
|
|
||||||
扎实 zha1 shi
|
|
||||||
扎腰带 za1 yao1 dai4
|
|
||||||
轧朋友 ga2 peng2 you3
|
|
||||||
爪子 zhua3 zi
|
|
||||||
折腾 zhe1 teng
|
|
||||||
着实 zhuo2 shi2
|
|
||||||
着我旧时裳 zhuo2 wo3 jiu4 shi2 chang2
|
|
||||||
枝蔓 zhi1 man4
|
|
||||||
中鹄 zhong1 hu2
|
|
||||||
中选 zhong4 xuan3
|
|
||||||
猪圈 zhu1 juan4
|
|
||||||
拽住不放 zhuai4 zhu4 bu4 fang4
|
|
||||||
转悠 zhuan4 you
|
|
||||||
庄稼熟了 zhuang1 jia shou2 le
|
|
||||||
酌量 zhuo2 liang2
|
|
||||||
罪行累累 zui4 xing2 lei3 lei3
|
|
||||||
一手 yi4 shou3
|
|
||||||
一去不复返 yi2 qu4 bu2 fu4 fan3
|
|
||||||
一颗 yi4 ke1
|
|
||||||
一件 yi2 jian4
|
|
||||||
一斤 yi4 jin1
|
|
||||||
一点 yi4 dian3
|
|
||||||
一朵 yi4 duo3
|
|
||||||
一声 yi4 sheng1
|
|
||||||
一身 yi4 shen1
|
|
||||||
不要 bu2 yao4
|
|
||||||
一人 yi4 ren2
|
|
||||||
一个 yi2 ge4
|
|
||||||
一把 yi4 ba3
|
|
||||||
一门 yi4 men2
|
|
||||||
一門 yi4 men2
|
|
||||||
一艘 yi4 sou1
|
|
||||||
一片 yi2 pian4
|
|
||||||
一篇 yi2 pian1
|
|
||||||
一份 yi2 fen4
|
|
||||||
好嗲 hao3 dia3
|
|
||||||
随地 sui2 di4
|
|
||||||
扁担长 bian3 dan4 chang3
|
|
||||||
一堆 yi4 dui1
|
|
||||||
不义 bu2 yi4
|
|
||||||
放一放 fang4 yi2 fang4
|
|
||||||
一米 yi4 mi3
|
|
||||||
一顿 yi2 dun4
|
|
||||||
一层楼 yi4 ceng2 lou2
|
|
||||||
一条 yi4 tiao2
|
|
||||||
一件 yi2 jian4
|
|
||||||
一棵 yi4 ke1
|
|
||||||
一小股 yi4 xiao3 gu3
|
|
||||||
一拐一拐 yi4 guai3 yi4 guai3
|
|
||||||
一根 yi4 gen1
|
|
||||||
沆瀣一气 hang4 xie4 yi2 qi4
|
|
||||||
一丝 yi4 si1
|
|
||||||
一毫 yi4 hao2
|
|
||||||
一樣 yi2 yang4
|
|
||||||
处处 chu4 chu4
|
|
||||||
一餐 yi4 can
|
|
||||||
永不 yong3 bu2
|
|
||||||
一看 yi2 kan4
|
|
||||||
一架 yi2 jia4
|
|
||||||
送还 song4 huan2
|
|
||||||
一见 yi2 jian4
|
|
||||||
一座 yi2 zuo4
|
|
||||||
一块 yi2 kuai4
|
|
||||||
一天 yi4 tian1
|
|
||||||
一只 yi4 zhi1
|
|
||||||
一支 yi4 zhi1
|
|
||||||
一字 yi2 zi4
|
|
||||||
一句 yi2 ju4
|
|
||||||
一张 yi4 zhang1
|
|
||||||
一條 yi4 tiao2
|
|
||||||
一场 yi4 chang3
|
|
||||||
一粒 yi2 li4
|
|
||||||
小俩口 xiao3 liang3 kou3
|
|
||||||
一首 yi4 shou3
|
|
||||||
一对 yi2 dui4
|
|
||||||
一手 yi4 shou3
|
|
||||||
又一村 you4 yi4 cun1
|
|
||||||
一概而论 yi2 gai4 er2 lun4
|
|
||||||
一峰峰 yi4 feng1 feng1
|
|
||||||
不但 bu2 dan4
|
|
||||||
一笑 yi2 xiao4
|
|
||||||
挠痒痒 nao2 yang3 yang
|
|
||||||
不对 bu2 dui4
|
|
||||||
拧开 ning3 kai1
|
|
||||||
爱不释手 ai4 bu2 shi4 shou3
|
|
||||||
一念 yi2 nian4
|
|
||||||
夺得 duo2 de2
|
|
||||||
一袭 yi4 xi2
|
|
||||||
一定 yi2 ding4
|
|
||||||
不慎 bu2 shen4
|
|
||||||
剽窃 piao2 qie4
|
|
||||||
一时 yi4 shi2
|
|
||||||
撇开 pie3 kai1
|
|
||||||
一祭 yi2 ji4
|
|
||||||
发卡 fa4 qia3
|
|
||||||
少不了 shao3 bu4 liao3
|
|
||||||
千虑一失 qian1 lv4 yi4 shi1
|
|
||||||
呛得 qiang4 de2
|
|
||||||
切菜 qie1 cai4
|
|
||||||
茄盒 qie2 he2
|
|
||||||
不去 bu2 qu4
|
|
||||||
一大圈 yi2 da4 quan1
|
|
||||||
不再 bu2 zai4
|
|
||||||
一群 yi4 qun2
|
|
||||||
不必 bu2 bi4
|
|
||||||
一些 yi4 xie1
|
|
||||||
一路 yi2 lu4
|
|
||||||
一股 yi4 gu3
|
|
||||||
一到 yi2 dao4
|
|
||||||
一拨 yi4 bo1
|
|
||||||
一排 yi4 pai2
|
|
||||||
一空 yi4 kong1
|
|
||||||
吮吸着 shun3 xi1 zhe
|
|
||||||
不适合 bu2 shi4 he2
|
|
||||||
一串串 yi2 chuan4 chuan4
|
|
||||||
一提起 yi4 ti2 qi3
|
|
||||||
一尘不染 yi4 chen2 bu4 ran3
|
|
||||||
一生 yi4 sheng1
|
|
||||||
一派 yi2 pai4
|
|
||||||
不断 bu2 duan4
|
|
||||||
一次 yi2 ci4
|
|
||||||
不进步 bu2 jin4 bu4
|
|
||||||
娃娃 wa2 wa
|
|
||||||
万户侯 wan4 hu4 hou2
|
|
||||||
一方 yi4 fang1
|
|
||||||
一番话 yi4 fan1 hua4
|
|
||||||
一遍 yi2 bian4
|
|
||||||
不计较 bu2 ji4 jiao4
|
|
||||||
诇 xiong4
|
|
||||||
一边 yi4 bian1
|
|
||||||
一束 yi2 shu4
|
|
||||||
一听到 yi4 ting1 dao4
|
|
||||||
炸鸡 zha2 ji1
|
|
||||||
乍暧还寒 zha4 ai4 huan2 han2
|
|
||||||
我说诶 wo3 shuo1 ei1
|
|
||||||
棒诶 bang4 ei1
|
|
||||||
寒碜 han2 chen4
|
|
||||||
应采儿 ying4 cai3 er2
|
|
||||||
晕车 yun1 che1
|
|
||||||
必应 bi4 ying4
|
|
||||||
应援 ying4 yuan2
|
|
||||||
应力 ying4 li4
|
|
@ -1,73 +0,0 @@
|
|||||||
# This file is copied from
|
|
||||||
# https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py
|
|
||||||
# Pause symbols; they occupy the first positions of the symbol table.
_pause = "sil eos sp #0 #1 #2 #3".split()

# Initials; "^" is the placeholder used by pinyin_dict for syllables
# that have no initial consonant.
_initials = "^ b c ch d f g h j k l m n p q r s sh t x z zh".split()

# Tone digits that get appended to every final.
_tones = list("12345")

_finals = (
    "a ai an ang ao "
    "e ei en eng er "
    "i ia ian iang iao ie ii iii in ing iong iou "
    "o ong ou "
    "u ua uai uan uang uei uen ueng uo "
    "v van ve vn"
).split()

# Full symbol table: pauses, then initials, then every final combined with
# every tone (all tones of one final are adjacent).
symbols = _pause + _initials + [final + tone for final in _finals for tone in _tones]
|
|
@ -1,137 +0,0 @@
|
|||||||
# This file is modified from
|
|
||||||
# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
# Note pinyin_dict is from ./pinyin_dict.py
|
|
||||||
from pinyin_dict import pinyin_dict
|
|
||||||
from pypinyin import Style
|
|
||||||
from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
|
|
||||||
from pypinyin.converter import DefaultConverter
|
|
||||||
from pypinyin.core import Pinyin, load_phrases_dict
|
|
||||||
|
|
||||||
|
|
||||||
class _MyConverter(NeutralToneWith5Mixin, DefaultConverter):
    """pypinyin converter: ``DefaultConverter`` plus ``NeutralToneWith5Mixin``."""
|
|
||||||
|
|
||||||
|
|
||||||
class Tokenizer:
    """Convert Chinese text to phoneme tokens and, optionally, to token IDs.

    Text is first converted to tone-numbered pinyin with pypinyin and then
    each pinyin syllable is split into two phonemes using ``pinyin_dict``.
    """

    def __init__(self, tokens: str = ""):
        """
        Args:
          tokens:
            Path to the token table. If it is empty, no table is loaded and
            the attributes ``token2id``, ``vocab_size`` and ``pad_id`` are
            not set; only the text -> token conversion can be used then.
        """
        self._load_pinyin_dict()
        self._pinyin_parser = Pinyin(_MyConverter())

        if tokens != "":
            self._load_tokens(tokens)

    def texts_to_token_ids(self, texts: List[str], **kwargs) -> List[List[int]]:
        """
        Args:
          texts:
            A list of sentences.
          kwargs:
            Not used. It is for compatibility with other TTS recipes in icefall.
        Returns:
          Return a list of token ID lists, one per input sentence.
        """
        tokens = [self.text_to_tokens(text) for text in texts]
        return self.tokens_to_token_ids(tokens)

    def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]:
        """Map token strings to IDs using the loaded token table.

        Tokens that are not in the table are skipped with a warning, so an
        output list may be shorter than the corresponding input list.
        """
        ans = []

        for token_list in tokens:
            token_ids = []
            for t in token_list:
                if t not in self.token2id:
                    logging.warning(f"Skip OOV {t}")
                    continue
                token_ids.append(self.token2id[t])
            ans.append(token_ids)

        return ans

    def text_to_tokens(self, text: str) -> List[str]:
        """Convert one sentence to a list of phoneme tokens.

        The result always starts with "sil" and ends with "eos".
        """
        # Convert "," to ["sp", "sil"]
        # Convert "。" to ["sil"]
        # append ["eos"] at the end of a sentence
        phonemes = ["sil"]
        pinyins = self._pinyin_parser.pinyin(
            text,
            style=Style.TONE3,
            # Non-pinyin input is passed through one character at a time.
            errors=lambda x: [[w] for w in x],
        )

        new_pinyin = []
        for p in pinyins:
            p = p[0]
            if p == ",":
                new_pinyin.extend(["sp", "sil"])
            elif p == "。":
                new_pinyin.append("sil")
            else:
                new_pinyin.append(p)
        sub_phonemes = self._get_phoneme4pinyin(new_pinyin)
        sub_phonemes.append("eos")
        phonemes.extend(sub_phonemes)
        return phonemes

    def _get_phoneme4pinyin(self, pinyins):
        """Split tone-numbered pinyin syllables into phoneme tokens.

        "sil" and "sp" are copied through. For any other item, the last
        character is treated as the tone and the rest is looked up in
        ``pinyin_dict``; items that are not found there are dropped.
        """
        result = []
        for pinyin in pinyins:
            if pinyin in ("sil", "sp"):
                result.append(pinyin)
            elif pinyin[:-1] in pinyin_dict:
                tone = pinyin[-1]
                a = pinyin[:-1]
                a1, a2 = pinyin_dict[a]
                # every word is appended with a #0
                result += [a1, a2 + tone, "#0"]

        return result

    def _load_pinyin_dict(self):
        """Register the phrases from ./pypinyin-local.dict with pypinyin.

        Each non-empty line has the form ``<hanzi> <pinyin> [<pinyin> ...]``.
        """
        this_dir = Path(__file__).parent.resolve()
        my_dict = {}
        with open(f"{this_dir}/pypinyin-local.dict", "r", encoding="utf-8") as f:
            for line in f:
                cuts = line.split()
                if not cuts:
                    # Tolerate blank lines instead of failing on cuts[0].
                    continue
                hanzi = cuts[0]
                pinyin = cuts[1:]
                my_dict[hanzi] = [[p] for p in pinyin]

        load_phrases_dict(my_dict)

    def _load_tokens(self, filename):
        """Load the token table and set token2id, vocab_size and pad_id.

        Each non-empty line is ``<token> <id>``. A line holding only an ID
        denotes the space token. The table must contain "#0", whose ID is
        used as ``pad_id``.

        Args:
          filename:
            Path to the token table.
        """
        token2id = {}  # maps token (str) -> ID (int)

        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                info = line.rstrip().split()
                if not info:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue

                if len(info) == 1:
                    # case of space
                    token = " "
                    idx = int(info[0])
                else:
                    token, idx = info[0], int(info[1])

                assert token not in token2id, token

                token2id[token] = idx

        self.token2id = token2id
        self.vocab_size = len(self.token2id)
        self.pad_id = self.token2id["#0"]
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Smoke test: print the phoneme tokens of a short sentence.

    No token table is loaded, so only the text -> token step is exercised.
    """
    tokenizer = Tokenizer()
    # The previous call targeted `_sentence_to_ids`, which Tokenizer does
    # not define; `text_to_tokens` is the method that needs no token table.
    print(tokenizer.text_to_tokens("你好,好的。"))


if __name__ == "__main__":
    main()
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/local/validate_manifest.py
|
|
@ -1,124 +0,0 @@
|
|||||||
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

# Stages in [stage, stop_stage] are executed. Both values, as well as
# dl_dir, can be overridden on the command line via parse_options.sh.
stage=-1
stop_stage=100

dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ "$stage" -le 0 ] && [ "$stop_stage" -ge 0 ]; then
  log "Stage 0: build monotonic_align lib"
  if [ ! -d vits/monotonic_align/build ]; then
    # Run in a subshell so the working directory of this script is untouched.
    (
      cd vits/monotonic_align
      python3 setup.py build_ext --inplace
    )
  else
    log "monotonic_align lib already built"
  fi
fi

if [ "$stage" -le 1 ] && [ "$stop_stage" -ge 1 ]; then
  log "Stage 1: Download data"

  # The directory $dl_dir/BZNSYP will contain 3 sub directories:
  #   - PhoneLabeling
  #   - ProsodyLabeling
  #   - Wave

  # If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink
  #
  #   ln -sfv /path/to/BZNSYP $dl_dir/
  #   touch $dl_dir/BZNSYP/.completed
  #
  if [ ! -d "$dl_dir/BZNSYP" ]; then
    lhotse download baker-zh "$dl_dir"
  fi
fi

if [ "$stage" -le 2 ] && [ "$stop_stage" -ge 2 ]; then
  log "Stage 2: Prepare baker-zh manifest"
  # We assume that you have downloaded the baker corpus
  # to $dl_dir/BZNSYP
  mkdir -p data/manifests
  if [ ! -e data/manifests/.baker.done ]; then
    lhotse prepare baker-zh "$dl_dir/BZNSYP" data/manifests
    touch data/manifests/.baker.done
  fi
fi

if [ "$stage" -le 3 ] && [ "$stop_stage" -ge 3 ]; then
  log "Stage 3: Compute spectrogram for baker (may take 3 minutes)"
  mkdir -p data/spectrogram
  if [ ! -e data/spectrogram/.baker.done ]; then
    ./local/compute_spectrogram_baker.py
    touch data/spectrogram/.baker.done
  fi

  if [ ! -e data/spectrogram/.baker-validated.done ]; then
    log "Validating data/spectrogram for baker"
    python3 ./local/validate_manifest.py \
      data/spectrogram/baker_zh_cuts_all.jsonl.gz
    touch data/spectrogram/.baker-validated.done
  fi
fi

if [ "$stage" -le 4 ] && [ "$stop_stage" -ge 4 ]; then
  log "Stage 4: Prepare tokens for baker-zh (may take 20 seconds)"
  if [ ! -e data/spectrogram/.baker_zh_with_token.done ]; then

    ./local/prepare_tokens_baker_zh.py

    # The tokenized cuts replace the original manifest in place.
    mv -v data/spectrogram/baker_zh_cuts_with_tokens_all.jsonl.gz \
      data/spectrogram/baker_zh_cuts_all.jsonl.gz

    touch data/spectrogram/.baker_zh_with_token.done
  fi
fi

if [ "$stage" -le 5 ] && [ "$stop_stage" -ge 5 ]; then
  log "Stage 5: Split the baker-zh cuts into train, valid and test sets (may take 25 seconds)"
  if [ ! -e data/spectrogram/.baker_zh_split.done ]; then
    # The last 600 cuts are held out: the first 100 of them form the
    # validation set and the remaining 500 form the test set.
    lhotse subset --last 600 \
      data/spectrogram/baker_zh_cuts_all.jsonl.gz \
      data/spectrogram/baker_zh_cuts_validtest.jsonl.gz
    lhotse subset --first 100 \
      data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \
      data/spectrogram/baker_zh_cuts_valid.jsonl.gz
    lhotse subset --last 500 \
      data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \
      data/spectrogram/baker_zh_cuts_test.jsonl.gz

    rm data/spectrogram/baker_zh_cuts_validtest.jsonl.gz

    # Everything before the held-out 600 cuts is used for training.
    n=$(( $(gunzip -c data/spectrogram/baker_zh_cuts_all.jsonl.gz | wc -l) - 600 ))
    lhotse subset --first "$n" \
      data/spectrogram/baker_zh_cuts_all.jsonl.gz \
      data/spectrogram/baker_zh_cuts_train.jsonl.gz
    touch data/spectrogram/.baker_zh_split.done
  fi
fi

if [ "$stage" -le 6 ] && [ "$stop_stage" -ge 6 ]; then
  log "Stage 6: Generate token file"
  if [ ! -e data/tokens.txt ]; then
    ./local/prepare_token_file.py --tokens data/tokens.txt
  fi
fi
|
|
@ -1 +0,0 @@
|
|||||||
../../../icefall/shared
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/duration_predictor.py
|
|
@ -1,414 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
#
|
|
||||||
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""
|
|
||||||
This script exports a VITS model from PyTorch to ONNX.
|
|
||||||
|
|
||||||
Export the model to ONNX:
|
|
||||||
./vits/export-onnx.py \
|
|
||||||
--epoch 1000 \
|
|
||||||
--exp-dir vits/exp \
|
|
||||||
--tokens data/tokens.txt
|
|
||||||
|
|
||||||
It will generate one file inside vits/exp:
|
|
||||||
- vits-epoch-1000.onnx
|
|
||||||
|
|
||||||
See ./test_onnx.py for how to use the exported ONNX models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Tuple
|
|
||||||
|
|
||||||
import onnx
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from tokenizer import Tokenizer
|
|
||||||
from train import get_model, get_params
|
|
||||||
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
    """Build the command-line parser of the ONNX export script.

    Returns:
      Return an ``argparse.ArgumentParser`` that accepts ``--epoch``,
      ``--exp-dir``, ``--tokens`` and ``--model-type``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # (flag, keyword arguments for add_argument), in registration order.
    options = (
        (
            "--epoch",
            dict(
                type=int,
                default=1000,
                help=(
                    "It specifies the checkpoint to use for decoding.\n"
                    "Note: Epoch counts from 1.\n"
                ),
            ),
        ),
        (
            "--exp-dir",
            dict(type=str, default="vits/exp", help="The experiment dir"),
        ),
        (
            "--tokens",
            dict(type=str, default="data/tokens.txt", help="Path to vocabulary."),
        ),
        (
            "--model-type",
            dict(
                type=str,
                default="high",
                choices=["low", "medium", "high"],
                help=(
                    "If not empty, valid values are: low, medium, high.\n"
                    "It controls the model size. low -> runs faster.\n"
                ),
            ),
        ),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    return parser
|
|
||||||
|
|
||||||
|
|
||||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Attach key-value metadata to an ONNX model file.

    The file is loaded, extended and written back to the same path.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs. Every value is stored as ``str(value)``.
    """
    onnx_model = onnx.load(filename)

    for name in meta_data:
        entry = onnx_model.metadata_props.add()
        entry.key = name
        entry.value = str(meta_data[name])

    onnx.save(onnx_model, filename)
|
|
||||||
|
|
||||||
|
|
||||||
class OnnxModel(nn.Module):
    """A wrapper for VITS generator.

    It exposes ``model.generator.inference`` as ``forward`` and keeps only
    the generated audio, so that the module can be traced for ONNX export.
    """

    def __init__(self, model: nn.Module):
        """
        Args:
          model:
            The VITS model. Its ``generator`` attribute is the one used
            for inference.
        """
        super().__init__()
        self.model = model

    def forward(
        self,
        tokens: torch.Tensor,
        tokens_lens: torch.Tensor,
        noise_scale: float = 0.667,
        alpha: float = 1.0,
        noise_scale_dur: float = 0.8,
    ) -> torch.Tensor:
        """Please see the help information of VITS.inference_batch

        Args:
          tokens:
            Input text token indexes (1, T_text)
          tokens_lens:
            Number of tokens of shape (1,)
          noise_scale (float):
            Noise scale parameter for flow.
          alpha (float):
            Alpha parameter to control the speed of generated speech.
          noise_scale_dur (float):
            Noise scale parameter for duration predictor.

        Returns:
          Return the generated waveform tensor, (B, T_wav). The other two
          values returned by the generator are discarded.
        """
        audio, _, _ = self.model.generator.inference(
            text=tokens,
            text_lengths=tokens_lens,
            noise_scale=noise_scale,
            noise_scale_dur=noise_scale_dur,
            alpha=alpha,
        )
        return audio
|
|
||||||
|
|
||||||
|
|
||||||
def export_model_onnx(
    model: nn.Module,
    model_filename: str,
    vocab_size: int,
    opset_version: int = 11,
) -> None:
    """Export the wrapped generator to ONNX and attach metadata to the file.

    The exported model has five inputs, in this order:

        - tokens, dtype torch.int64; traced with shape (1, 13), both axes
          are dynamic
        - tokens_lens, dtype torch.int64; axis 0 is dynamic
        - noise_scale, a float32 tensor with one element
        - alpha, a float32 tensor with one element
        - noise_scale_dur, a float32 tensor with one element

    and it has one output:

        - audio; both axes are dynamic

    Args:
      model:
        The wrapper to export. ``model.model.spks`` and
        ``model.model.sampling_rate`` are read for the metadata.
      model_filename:
        The filename to save the exported ONNX model.
      vocab_size:
        Number of tokens used in training. It bounds the random token IDs
        used for tracing.
      opset_version:
        The opset version to use.
    """
    # Dummy inputs used only for tracing.
    dummy_tokens = torch.randint(
        low=0, high=vocab_size, size=(1, 13), dtype=torch.int64
    )
    dummy_lens = torch.tensor([dummy_tokens.shape[1]], dtype=torch.int64)
    one = [1]
    dummy_noise_scale = torch.tensor(one, dtype=torch.float32)
    dummy_noise_scale_dur = torch.tensor(one, dtype=torch.float32)
    dummy_alpha = torch.tensor(one, dtype=torch.float32)

    # The positional order below must match both OnnxModel.forward and
    # `input_names`.
    torch.onnx.export(
        model,
        (
            dummy_tokens,
            dummy_lens,
            dummy_noise_scale,
            dummy_alpha,
            dummy_noise_scale_dur,
        ),
        model_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=[
            "tokens",
            "tokens_lens",
            "noise_scale",
            "alpha",
            "noise_scale_dur",
        ],
        output_names=["audio"],
        dynamic_axes={
            "tokens": {0: "N", 1: "T"},
            "tokens_lens": {0: "N"},
            "audio": {0: "N", 1: "T"},
        },
    )

    # A model without speaker information is reported as single-speaker.
    num_speakers = 1 if model.model.spks is None else model.model.spks

    meta_data = {
        "model_type": "vits",
        "version": "1",
        "model_author": "k2-fsa",
        "comment": "icefall",  # must be icefall for models from icefall
        "language": "Chinese",
        "n_speakers": num_speakers,
        "sample_rate": model.model.sampling_rate,  # Must match the real sample rate
    }
    logging.info(f"meta_data: {meta_data}")

    add_meta_data(filename=model_filename, meta_data=meta_data)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def main():
    """Load a checkpoint from the experiment dir and export it to ONNX.

    The output is written to ``<exp-dir>/vits-epoch-<epoch>.onnx``.
    """
    cli_args = get_parser().parse_args()
    cli_args.exp_dir = Path(cli_args.exp_dir)

    params = get_params()
    params.update(vars(cli_args))

    # The tokenizer supplies the vocabulary-dependent parameters.
    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size

    logging.info(params)

    logging.info("About to create model")
    vits = get_model(params)
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", vits)
    vits.to("cpu")
    vits.eval()

    wrapper = OnnxModel(model=vits)

    total = sum(p.numel() for p in wrapper.parameters())
    logging.info(f"generator parameters: {total}, or {total/1000/1000} M")

    logging.info("Exporting encoder")
    onnx_path = params.exp_dir / f"vits-epoch-{params.epoch}.onnx"
    export_model_onnx(
        wrapper,
        onnx_path,
        params.vocab_size,
        opset_version=13,
    )
    logging.info(f"Exported generator to {onnx_path}")


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()
|
|
||||||
|
|
||||||
"""
|
|
||||||
Supported languages.
|
|
||||||
|
|
||||||
LJSpeech is using "en-us" from the second column.
|
|
||||||
|
|
||||||
Pty Language Age/Gender VoiceName File Other Languages
|
|
||||||
5 af --/M Afrikaans gmw/af
|
|
||||||
5 am --/M Amharic sem/am
|
|
||||||
5 an --/M Aragonese roa/an
|
|
||||||
5 ar --/M Arabic sem/ar
|
|
||||||
5 as --/M Assamese inc/as
|
|
||||||
5 az --/M Azerbaijani trk/az
|
|
||||||
5 ba --/M Bashkir trk/ba
|
|
||||||
5 be --/M Belarusian zle/be
|
|
||||||
5 bg --/M Bulgarian zls/bg
|
|
||||||
5 bn --/M Bengali inc/bn
|
|
||||||
5 bpy --/M Bishnupriya_Manipuri inc/bpy
|
|
||||||
5 bs --/M Bosnian zls/bs
|
|
||||||
5 ca --/M Catalan roa/ca
|
|
||||||
5 chr-US-Qaaa-x-west --/M Cherokee_ iro/chr
|
|
||||||
5 cmn --/M Chinese_(Mandarin,_latin_as_English) sit/cmn (zh-cmn 5)(zh 5)
|
|
||||||
5 cmn-latn-pinyin --/M Chinese_(Mandarin,_latin_as_Pinyin) sit/cmn-Latn-pinyin (zh-cmn 5)(zh 5)
|
|
||||||
5 cs --/M Czech zlw/cs
|
|
||||||
5 cv --/M Chuvash trk/cv
|
|
||||||
5 cy --/M Welsh cel/cy
|
|
||||||
5 da --/M Danish gmq/da
|
|
||||||
5 de --/M German gmw/de
|
|
||||||
5 el --/M Greek grk/el
|
|
||||||
5 en-029 --/M English_(Caribbean) gmw/en-029 (en 10)
|
|
||||||
2 en-gb --/M English_(Great_Britain) gmw/en (en 2)
|
|
||||||
5 en-gb-scotland --/M English_(Scotland) gmw/en-GB-scotland (en 4)
|
|
||||||
5 en-gb-x-gbclan --/M English_(Lancaster) gmw/en-GB-x-gbclan (en-gb 3)(en 5)
|
|
||||||
5 en-gb-x-gbcwmd --/M English_(West_Midlands) gmw/en-GB-x-gbcwmd (en-gb 9)(en 9)
|
|
||||||
5 en-gb-x-rp --/M English_(Received_Pronunciation) gmw/en-GB-x-rp (en-gb 4)(en 5)
|
|
||||||
2 en-us --/M English_(America) gmw/en-US (en 3)
|
|
||||||
5 en-us-nyc --/M English_(America,_New_York_City) gmw/en-US-nyc
|
|
||||||
5 eo --/M Esperanto art/eo
|
|
||||||
5 es --/M Spanish_(Spain) roa/es
|
|
||||||
5 es-419 --/M Spanish_(Latin_America) roa/es-419 (es-mx 6)
|
|
||||||
5 et --/M Estonian urj/et
|
|
||||||
5 eu --/M Basque eu
|
|
||||||
5 fa --/M Persian ira/fa
|
|
||||||
5 fa-latn --/M Persian_(Pinglish) ira/fa-Latn
|
|
||||||
5 fi --/M Finnish urj/fi
|
|
||||||
5 fr-be --/M French_(Belgium) roa/fr-BE (fr 8)
|
|
||||||
5 fr-ch --/M French_(Switzerland) roa/fr-CH (fr 8)
|
|
||||||
5 fr-fr --/M French_(France) roa/fr (fr 5)
|
|
||||||
5 ga --/M Gaelic_(Irish) cel/ga
|
|
||||||
5 gd --/M Gaelic_(Scottish) cel/gd
|
|
||||||
5 gn --/M Guarani sai/gn
|
|
||||||
5 grc --/M Greek_(Ancient) grk/grc
|
|
||||||
5 gu --/M Gujarati inc/gu
|
|
||||||
5 hak --/M Hakka_Chinese sit/hak
|
|
||||||
5 haw --/M Hawaiian map/haw
|
|
||||||
5 he --/M Hebrew sem/he
|
|
||||||
5 hi --/M Hindi inc/hi
|
|
||||||
5 hr --/M Croatian zls/hr (hbs 5)
|
|
||||||
5 ht --/M Haitian_Creole roa/ht
|
|
||||||
5 hu --/M Hungarian urj/hu
|
|
||||||
5 hy --/M Armenian_(East_Armenia) ine/hy (hy-arevela 5)
|
|
||||||
5 hyw --/M Armenian_(West_Armenia) ine/hyw (hy-arevmda 5)(hy 8)
|
|
||||||
5 ia --/M Interlingua art/ia
|
|
||||||
5 id --/M Indonesian poz/id
|
|
||||||
5 io --/M Ido art/io
|
|
||||||
5 is --/M Icelandic gmq/is
|
|
||||||
5 it --/M Italian roa/it
|
|
||||||
5 ja --/M Japanese jpx/ja
|
|
||||||
5 jbo --/M Lojban art/jbo
|
|
||||||
5 ka --/M Georgian ccs/ka
|
|
||||||
5 kk --/M Kazakh trk/kk
|
|
||||||
5 kl --/M Greenlandic esx/kl
|
|
||||||
5 kn --/M Kannada dra/kn
|
|
||||||
5 ko --/M Korean ko
|
|
||||||
5 kok --/M Konkani inc/kok
|
|
||||||
5 ku --/M Kurdish ira/ku
|
|
||||||
5 ky --/M Kyrgyz trk/ky
|
|
||||||
5 la --/M Latin itc/la
|
|
||||||
5 lb --/M Luxembourgish gmw/lb
|
|
||||||
5 lfn --/M Lingua_Franca_Nova art/lfn
|
|
||||||
5 lt --/M Lithuanian bat/lt
|
|
||||||
5 ltg --/M Latgalian bat/ltg
|
|
||||||
5 lv --/M Latvian bat/lv
|
|
||||||
5 mi --/M Māori poz/mi
|
|
||||||
5 mk --/M Macedonian zls/mk
|
|
||||||
5 ml --/M Malayalam dra/ml
|
|
||||||
5 mr --/M Marathi inc/mr
|
|
||||||
5 ms --/M Malay poz/ms
|
|
||||||
5 mt --/M Maltese sem/mt
|
|
||||||
5 mto --/M Totontepec_Mixe miz/mto
|
|
||||||
5 my --/M Myanmar_(Burmese) sit/my
|
|
||||||
5 nb --/M Norwegian_Bokmål gmq/nb (no 5)
|
|
||||||
5 nci --/M Nahuatl_(Classical) azc/nci
|
|
||||||
5 ne --/M Nepali inc/ne
|
|
||||||
5 nl --/M Dutch gmw/nl
|
|
||||||
5 nog --/M Nogai trk/nog
|
|
||||||
5 om --/M Oromo cus/om
|
|
||||||
5 or --/M Oriya inc/or
|
|
||||||
5 pa --/M Punjabi inc/pa
|
|
||||||
5 pap --/M Papiamento roa/pap
|
|
||||||
5 piqd --/M Klingon art/piqd
|
|
||||||
5 pl --/M Polish zlw/pl
|
|
||||||
5 pt --/M Portuguese_(Portugal) roa/pt (pt-pt 5)
|
|
||||||
5 pt-br --/M Portuguese_(Brazil) roa/pt-BR (pt 6)
|
|
||||||
5 py --/M Pyash art/py
|
|
||||||
5 qdb --/M Lang_Belta art/qdb
|
|
||||||
5 qu --/M Quechua qu
|
|
||||||
5 quc --/M K'iche' myn/quc
|
|
||||||
5 qya --/M Quenya art/qya
|
|
||||||
5 ro --/M Romanian roa/ro
|
|
||||||
5 ru --/M Russian zle/ru
|
|
||||||
5 ru-cl --/M Russian_(Classic) zle/ru-cl
|
|
||||||
2 ru-lv --/M Russian_(Latvia) zle/ru-LV
|
|
||||||
5 sd --/M Sindhi inc/sd
|
|
||||||
5 shn --/M Shan_(Tai_Yai) tai/shn
|
|
||||||
5 si --/M Sinhala inc/si
|
|
||||||
5 sjn --/M Sindarin art/sjn
|
|
||||||
5 sk --/M Slovak zlw/sk
|
|
||||||
5 sl --/M Slovenian zls/sl
|
|
||||||
5 smj --/M Lule_Saami urj/smj
|
|
||||||
5 sq --/M Albanian ine/sq
|
|
||||||
5 sr --/M Serbian zls/sr
|
|
||||||
5 sv --/M Swedish gmq/sv
|
|
||||||
5 sw --/M Swahili bnt/sw
|
|
||||||
5 ta --/M Tamil dra/ta
|
|
||||||
5 te --/M Telugu dra/te
|
|
||||||
5 th --/M Thai tai/th
|
|
||||||
5 tk --/M Turkmen trk/tk
|
|
||||||
5 tn --/M Setswana bnt/tn
|
|
||||||
5 tr --/M Turkish trk/tr
|
|
||||||
5 tt --/M Tatar trk/tt
|
|
||||||
5 ug --/M Uyghur trk/ug
|
|
||||||
5 uk --/M Ukrainian zle/uk
|
|
||||||
5 ur --/M Urdu inc/ur
|
|
||||||
5 uz --/M Uzbek trk/uz
|
|
||||||
5 vi --/M Vietnamese_(Northern) aav/vi
|
|
||||||
5 vi-vn-x-central --/M Vietnamese_(Central) aav/vi-VN-x-central
|
|
||||||
5 vi-vn-x-south --/M Vietnamese_(Southern) aav/vi-VN-x-south
|
|
||||||
5 yue --/M Chinese_(Cantonese) sit/yue (zh-yue 5)(zh 8)
|
|
||||||
5 yue --/M Chinese_(Cantonese,_latin_as_Jyutping) sit/yue-Latn-jyutping (zh-yue 5)(zh 8)
|
|
||||||
"""
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/flow.py
|
|
@ -1,39 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
from pypinyin import phrases_dict, pinyin_dict
|
|
||||||
from tokenizer import Tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Write ``lexicon.txt``: one ``<word> <tokens...>`` line per entry.

    Entries are generated for every pypinyin single-character entry whose
    code point lies in 0x4E00..0x9FFF, followed by every pypinyin phrase.
    """
    filename = "lexicon.txt"
    tokenizer = Tokenizer("./data/tokens.txt")

    word_dict = pinyin_dict.pinyin_dict
    phrases = phrases_dict.phrases_dict

    with open(filename, "w", encoding="utf-8") as f:
        for key in word_dict:
            # Keys are code points; keep only those in this range.
            if not (0x4E00 <= key <= 0x9FFF):
                continue

            w = chr(key)

            # 1 to remove the initial sil
            # :-1 to remove the final eos
            phonemes = " ".join(tokenizer.text_to_tokens(w)[1:-1])
            f.write(f"{w} {phonemes}\n")

        for key in phrases:
            # 1 to remove the initial sil
            # :-1 to remove the final eos
            phonemes = " ".join(tokenizer.text_to_tokens(key)[1:-1])
            f.write(f"{key} {phonemes}\n")


if __name__ == "__main__":
    main()
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/generator.py
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/hifigan.py
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/loss.py
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/monotonic_align
|
|
@ -1 +0,0 @@
|
|||||||
../local/pinyin_dict.py
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/posterior_encoder.py
|
|
@ -1 +0,0 @@
|
|||||||
../local/pypinyin-local.dict
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/residual_coupling.py
|
|
@ -1,142 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
#
|
|
||||||
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""
|
|
||||||
This script is used to test the exported onnx model by vits/export-onnx.py
|
|
||||||
|
|
||||||
Use the onnx model to generate a wav:
|
|
||||||
./vits/test_onnx.py \
|
|
||||||
--model-filename vits/exp/vits-epoch-1000.onnx \
|
|
||||||
--tokens data/tokens.txt
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import onnxruntime as ort
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
from tokenizer import Tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
    """Build the command-line parser of the ONNX test script.

    Returns:
      Return an ``argparse.ArgumentParser`` that accepts
      ``--model-filename`` (required), ``--tokens``, ``--text`` and
      ``--output-filename``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-filename",
        type=str,
        required=True,
        help="Path to the onnx model.",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    parser.add_argument(
        "--text",
        type=str,
        # The tokenizer of this recipe works on Chinese text, so the
        # default has to be Chinese; an English sentence yields no phonemes.
        default="今天天气真不错,我们出去走走吧。",
        help="Text to generate speech for",
    )

    parser.add_argument(
        "--output-filename",
        type=str,
        default="test_onnx.wav",
        help="Filename to save the generated wave file.",
    )

    return parser
|
|
||||||
|
|
||||||
|
|
||||||
class OnnxModel:
    """Run an exported VITS ONNX model with onnxruntime on CPU."""

    def __init__(self, model_filename: str):
        """
        Args:
          model_filename:
            Path to the onnx model. Its custom metadata must contain
            ``sample_rate``.
        """
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        self.session_opts = opts

        self.model = ort.InferenceSession(
            model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")

        metadata = self.model.get_modelmeta().custom_metadata_map
        self.sample_rate = int(metadata["sample_rate"])

    def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
          tokens:
            A tensor of shape (1, T)
          tokens_lens:
            A tensor holding the number of tokens.
        Returns:
          A tensor of shape (1, T')
        """
        # Fixed inference settings used by this test script.
        noise_scale = torch.tensor([0.667], dtype=torch.float32)
        noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
        alpha = torch.tensor([1.0], dtype=torch.float32)

        # Inputs are addressed by position; the order is the one used when
        # the model was exported.
        input_names = [node.name for node in self.model.get_inputs()]
        feed = {
            input_names[0]: tokens.numpy(),
            input_names[1]: tokens_lens.numpy(),
            input_names[2]: noise_scale.numpy(),
            input_names[3]: alpha.numpy(),
            input_names[4]: noise_scale_dur.numpy(),
        }
        output_name = self.model.get_outputs()[0].name

        out = self.model.run([output_name], feed)[0]
        return torch.from_numpy(out)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Synthesize ``--text`` with the ONNX model and save it as a wave file."""
    args = get_parser().parse_args()
    logging.info(vars(args))

    tokenizer = Tokenizer(args.tokens)

    logging.info("About to create onnx model")
    model = OnnxModel(args.model_filename)

    token_ids = tokenizer.texts_to_token_ids([args.text])
    tokens = torch.tensor(token_ids)  # (1, T)
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1,)
    audio = model(tokens, tokens_lens)  # (1, T')

    torchaudio.save(args.output_filename, audio, sample_rate=model.sample_rate)
    logging.info(f"Saved to {args.output_filename}")


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/text_encoder.py
|
|
@ -1 +0,0 @@
|
|||||||
../local/tokenizer.py
|
|
@ -1,927 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from shutil import copyfile
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import k2
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
import torch.nn as nn
|
|
||||||
from lhotse.cut import Cut
|
|
||||||
from lhotse.utils import fix_random_seed
|
|
||||||
from tokenizer import Tokenizer
|
|
||||||
from torch.cuda.amp import GradScaler, autocast
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
from tts_datamodule import BakerZhSpeechTtsDataModule
|
|
||||||
from utils import MetricsTracker, plot_feature, save_checkpoint
|
|
||||||
from vits import VITS
|
|
||||||
|
|
||||||
from icefall import diagnostics
|
|
||||||
from icefall.checkpoint import load_checkpoint
|
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
|
||||||
from icefall.env import get_env_info
|
|
||||||
from icefall.hooks import register_inf_check_hooks
|
|
||||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
|
||||||
|
|
||||||
# Alias used in the annotations of the generator/discriminator LR schedulers.
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
    """Build the command-line parser for the training script.

    Returns:
      An ``argparse.ArgumentParser`` that uses
      ``ArgumentDefaultsHelpFormatter``, so every option's default value is
      appended to its help text.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs for DDP training.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12354,
        help="Master port to use for DDP training.",
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Should various information be logged in tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=1000,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=1,
        help="""Resume training from this epoch. It should be positive.
        If larger than 1, it will load checkpoint from
        exp-dir/epoch-{start_epoch-1}.pt
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="vits/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    parser.add_argument(
        "--lr", type=float, default=2.0e-4, help="The base learning rate."
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators intended for reproducibility",
    )

    parser.add_argument(
        "--print-diagnostics",
        type=str2bool,
        default=False,
        help="Accumulate stats on activations, print them and exit.",
    )

    parser.add_argument(
        "--inf-check",
        type=str2bool,
        default=False,
        help="Add hooks to check for infinite module outputs and gradients.",
    )

    # argparse expands help strings with the % operator, so a literal
    # percent sign has to be written as %%.
    parser.add_argument(
        "--save-every-n",
        type=int,
        default=20,
        help="""Save checkpoint after processing this number of epochs
        periodically. We save checkpoint to exp-dir/ whenever
        params.cur_epoch %% save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
        Since it will take around 1000 epochs, we suggest using a large
        save_every_n to save disk space.
        """,
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=False,
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--model-type",
        type=str,
        default="high",
        choices=["low", "medium", "high"],
        help="""If not empty, valid values are: low, medium, high.
        It controls the model size. low -> runs faster.
        """,
    )

    return parser
|
|
||||||
|
|
||||||
|
|
||||||
def get_params() -> AttributeDict:
    """Return the training parameters that are not set on the command line.

    Command-line options are merged into the returned object after parsing,
    so they can be read from it as well.

    Entries:

        - best_train_loss / best_valid_loss: lowest training / validation
                           loss seen so far; updated during training.

        - best_train_epoch / best_valid_epoch: the epochs at which those
                           lowest losses were reached.

        - batch_idx_train: number of batches trained so far across epochs;
                           used when writing statistics to tensorboard.

        - log_interval:    print the training loss when
                           batch_idx_train % log_interval is 0.

        - valid_interval:  run validation when
                           batch_idx_train % valid_interval is 0.

        - feature_dim:     the model input dim; it has to match the one
                           used when computing features.
    """
    bookkeeping = {
        # training params
        "best_train_loss": float("inf"),
        "best_valid_loss": float("inf"),
        "best_train_epoch": -1,
        "best_valid_epoch": -1,
        "batch_idx_train": -1,  # 0
        "log_interval": 50,
        "valid_interval": 200,
        "env_info": get_env_info(),
    }
    signal_cfg = {
        "sampling_rate": 48000,
        "frame_shift": 256,
        "frame_length": 1024,
        "feature_dim": 513,  # 1024 // 2 + 1, 1024 is fft_length
        "n_mels": 80,
    }
    loss_weights = {
        "lambda_adv": 1.0,  # loss scaling coefficient for adversarial loss
        "lambda_mel": 45.0,  # loss scaling coefficient for Mel loss
        "lambda_feat_match": 2.0,  # loss scaling coefficient for feat match loss
        "lambda_dur": 1.0,  # loss scaling coefficient for duration loss
        "lambda_kl": 1.0,  # loss scaling coefficient for KL divergence loss
    }

    # Merge in this order so the key order matches the original literal.
    return AttributeDict({**bookkeeping, **signal_cfg, **loss_weights})
|
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint_if_available(
    params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
    """Resume from ``exp_dir/epoch-{start_epoch-1}.pt`` when requested.

    Nothing is loaded unless ``params.start_epoch`` is larger than 1.
    Otherwise the model weights are restored from the checkpoint and
    `best_train_epoch`, `best_valid_epoch`, `batch_idx_train`,
    `best_train_loss` and `best_valid_loss` in `params` are overwritten
    with the saved values. Optimizer and scheduler states are not restored
    here; they are only available through the returned dict.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
    Returns:
      The dict of previously saved training info, or None when no
      checkpoint was loaded.
    """
    if params.start_epoch <= 1:
        return None

    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(filename, model=model)

    restored_keys = (
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    )
    for key in restored_keys:
        params[key] = saved_params[key]

    return saved_params
|
|
||||||
|
|
||||||
|
|
||||||
def get_model(params: AttributeDict) -> nn.Module:
    """Instantiate the VITS model from the values stored in ``params``.

    Args:
      params:
        Must provide vocab_size, feature_dim, sampling_rate, model_type,
        n_mels, frame_length, frame_shift and the lambda_* loss weights.
    Returns:
      The newly constructed ``VITS`` model.
    """
    mel_loss_params = dict(
        n_mels=params.n_mels,
        frame_length=params.frame_length,
        frame_shift=params.frame_shift,
    )
    return VITS(
        vocab_size=params.vocab_size,
        feature_dim=params.feature_dim,
        sampling_rate=params.sampling_rate,
        model_type=params.model_type,
        mel_loss_params=mel_loss_params,
        lambda_adv=params.lambda_adv,
        lambda_mel=params.lambda_mel,
        lambda_feat_match=params.lambda_feat_match,
        lambda_dur=params.lambda_dur,
        lambda_kl=params.lambda_kl,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
    """Move one batch to ``device`` and turn its tokens into a padded tensor.

    Args:
      batch:
        A dict with the keys "audio", "features", "audio_lens",
        "features_lens" and "tokens".
      tokenizer:
        Converts the token sequences to IDs and supplies ``pad_id``.
      device:
        The device the returned tensors are moved to.
    Returns:
      A tuple (audio, audio_lens, features, features_lens, tokens,
      tokens_lens).
    """
    audio = batch["audio"].to(device)
    features = batch["features"].to(device)
    audio_lens = batch["audio_lens"].to(device)
    features_lens = batch["features_lens"].to(device)

    token_ids = tokenizer.tokens_to_token_ids(batch["tokens"])
    ragged = k2.RaggedTensor(token_ids)

    # Per-utterance lengths are the differences of consecutive row splits.
    splits = ragged.shape.row_splits(1)
    tokens_lens = splits[1:] - splits[:-1]

    ragged = ragged.to(device)
    tokens_lens = tokens_lens.to(device)
    # a tensor of shape (B, T)
    tokens = ragged.pad(mode="constant", padding_value=tokenizer.pad_id)

    return audio, audio_lens, features, features_lens, tokens, tokens_lens
|
|
||||||
|
|
||||||
|
|
||||||
def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: Tokenizer,
    optimizer_g: Optimizer,
    optimizer_d: Optimizer,
    scheduler_g: LRSchedulerType,
    scheduler_d: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Run one training epoch, alternating discriminator and generator steps.

    The per-sample generator loss of the epoch is stored in
    `params.train_loss`, and `params.best_train_loss` /
    `params.best_train_epoch` are updated when it improves. Validation is
    run whenever `params.batch_idx_train % params.valid_interval == 0`.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      tokenizer:
        Passed to :func:`prepare_input` to convert tokens to IDs.
      optimizer_g:
        The optimizer for generator.
      optimizer_d:
        The optimizer for discriminator.
      scheduler_g:
        The learning rate scheduler for generator; only read here for
        logging (it is not stepped inside this function).
      scheduler_d:
        The learning rate scheduler for discriminator; only read here for
        logging (it is not stepped inside this function).
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mixed precision training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of nodes in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
    """
    model.train()
    if isinstance(model, DDP):
        device = model.device
    else:
        device = next(model.parameters()).device

    # used to track the stats over iterations in one epoch
    tot_loss = MetricsTracker()

    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
        # Dump the full training state for post-mortem debugging. rank=0 is
        # passed so that every process writes its own file (the filename
        # already contains the real rank).
        save_checkpoint(
            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            params=params,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,
            scheduler_d=scheduler_d,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=0,
        )

    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        # The same condition decides whether the generator returns a sample
        # and whether this batch is logged.
        is_log_step = params.batch_idx_train % params.log_interval == 0

        batch_size = len(batch["tokens"])
        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
            batch, tokenizer, device
        )
        model_inputs = dict(
            text=tokens,
            text_lengths=tokens_lens,
            feats=features,
            feats_lengths=features_lens,
            speech=audio,
            speech_lengths=audio_lens,
        )

        loss_info = MetricsTracker()
        loss_info["samples"] = batch_size

        try:
            with autocast(enabled=params.use_fp16):
                # forward discriminator
                loss_d, stats_d = model(**model_inputs, forward_generator=False)
            for name, value in stats_d.items():
                loss_info[name] = value * batch_size
            # update discriminator
            optimizer_d.zero_grad()
            scaler.scale(loss_d).backward()
            scaler.step(optimizer_d)

            with autocast(enabled=params.use_fp16):
                # forward generator
                loss_g, stats_g = model(
                    **model_inputs,
                    forward_generator=True,
                    return_sample=is_log_step,
                )
            for name, value in stats_g.items():
                if "returned_sample" not in name:
                    loss_info[name] = value * batch_size
            # update generator
            optimizer_g.zero_grad()
            scaler.scale(loss_g).backward()
            scaler.step(optimizer_g)
            scaler.update()

            # summary stats
            tot_loss = tot_loss + loss_info
        except:  # noqa
            save_bad_model()
            raise

        if params.print_diagnostics and batch_idx == 5:
            return

        if params.batch_idx_train % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The
            # _growth_interval of the grad scaler is configurable, but we
            # can't configure it to have different behavior depending on
            # the current grad scale.
            cur_grad_scale = scaler._scale.item()

            grow_now = cur_grad_scale < 8.0 or (
                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
            )
            if grow_now:
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                save_bad_model()
                raise RuntimeError(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )

        if is_log_step:
            cur_lr_g = max(scheduler_g.get_last_lr())
            cur_lr_d = max(scheduler_d.get_last_lr())
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
                f"loss[{loss_info}], tot_loss[{tot_loss}], "
                f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                step = params.batch_idx_train
                tb_writer.add_scalar("train/learning_rate_g", cur_lr_g, step)
                tb_writer.add_scalar("train/learning_rate_d", cur_lr_d, step)
                loss_info.write_summary(tb_writer, "train/current_", step)
                tot_loss.write_summary(tb_writer, "train/tot_", step)
                if params.use_fp16:
                    tb_writer.add_scalar("train/grad_scale", cur_grad_scale, step)
                if "returned_sample" in stats_g:
                    speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
                    tb_writer.add_audio(
                        "train/speech_hat_",
                        speech_hat_,
                        step,
                        params.sampling_rate,
                    )
                    tb_writer.add_audio(
                        "train/speech_",
                        speech_,
                        step,
                        params.sampling_rate,
                    )
                    tb_writer.add_image(
                        "train/mel_hat_",
                        plot_feature(mel_hat_),
                        step,
                        dataformats="HWC",
                    )
                    tb_writer.add_image(
                        "train/mel_",
                        plot_feature(mel_),
                        step,
                        dataformats="HWC",
                    )

        run_validation = (
            params.batch_idx_train % params.valid_interval == 0
            and not params.print_diagnostics
        )
        if run_validation:
            logging.info("Computing validation loss")
            valid_info, (speech_hat, speech) = compute_validation_loss(
                params=params,
                model=model,
                tokenizer=tokenizer,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            # compute_validation_loss() switches the model to eval mode.
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )
                tb_writer.add_audio(
                    "train/valdi_speech_hat",
                    speech_hat,
                    params.batch_idx_train,
                    params.sampling_rate,
                )
                tb_writer.add_audio(
                    "train/valdi_speech",
                    speech,
                    params.batch_idx_train,
                    params.sampling_rate,
                )

    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss
|
|
||||||
|
|
||||||
|
|
||||||
def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: Tokenizer,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
    rank: int = 0,
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
    """Run the validation process.

    The model is put into eval mode and left there. For the first batch
    (on rank 0 only) one utterance is synthesized so that the caller can
    log predicted and ground-truth audio. `params.best_valid_loss` and
    `params.best_valid_epoch` are updated when the per-sample generator
    loss improves.

    Returns:
      A tuple (tot_loss, returned_sample), where returned_sample is
      (predicted_audio, ground_truth_audio) as numpy arrays. It stays None
      when `rank` is not 0 or `valid_dl` yields no batch.
    """
    model.eval()
    if isinstance(model, DDP):
        device = model.device
    else:
        device = next(model.parameters()).device

    # used to summary the stats over iterations
    tot_loss = MetricsTracker()
    returned_sample = None

    with torch.no_grad():
        for batch_idx, batch in enumerate(valid_dl):
            batch_size = len(batch["tokens"])
            (
                audio,
                audio_lens,
                features,
                features_lens,
                tokens,
                tokens_lens,
            ) = prepare_input(batch, tokenizer, device)
            model_inputs = dict(
                text=tokens,
                text_lengths=tokens_lens,
                feats=features,
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
            )

            loss_info = MetricsTracker()
            loss_info["samples"] = batch_size

            # forward discriminator
            loss_d, stats_d = model(**model_inputs, forward_generator=False)
            assert loss_d.requires_grad is False
            for name, value in stats_d.items():
                loss_info[name] = value * batch_size

            # forward generator
            loss_g, stats_g = model(**model_inputs, forward_generator=True)
            assert loss_g.requires_grad is False
            for name, value in stats_g.items():
                loss_info[name] = value * batch_size

            # summary stats
            tot_loss = tot_loss + loss_info

            # infer for first batch:
            if batch_idx == 0 and rank == 0:
                inner_model = model.module if isinstance(model, DDP) else model
                first_len = tokens_lens[0].item()
                audio_pred, _, duration = inner_model.inference(
                    text=tokens[0, :first_len]
                )
                audio_pred = audio_pred.data.cpu().numpy()
                # Total predicted duration in frames times the hop size
                # should equal the number of generated samples.
                audio_len_pred = (
                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
                )
                assert audio_len_pred == len(audio_pred), (
                    audio_len_pred,
                    len(audio_pred),
                )
                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
                returned_sample = (audio_pred, audio_gt)

    if world_size > 1:
        tot_loss.reduce(device)

    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss, returned_sample
|
|
||||||
|
|
||||||
|
|
||||||
def scan_pessimistic_batches_for_oom(
    model: Union[nn.Module, DDP],
    train_dl: torch.utils.data.DataLoader,
    tokenizer: Tokenizer,
    optimizer_g: torch.optim.Optimizer,
    optimizer_d: torch.optim.Optimizer,
    params: AttributeDict,
):
    """Dry-run the batches picked by lhotse to surface OOM errors early.

    For every batch returned by ``find_pessimistic_batches`` a
    discriminator and a generator forward/backward pass is executed (no
    optimizer step is taken). Any exception is re-raised; when its message
    contains "CUDA out of memory" a hint about decreasing max_duration is
    logged first.
    """
    from lhotse.dataset import find_pessimistic_batches

    logging.info(
        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
    )
    if isinstance(model, DDP):
        device = model.device
    else:
        device = next(model.parameters()).device

    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
            batch, tokenizer, device
        )
        model_inputs = dict(
            text=tokens,
            text_lengths=tokens_lens,
            feats=features,
            feats_lengths=features_lens,
            speech=audio,
            speech_lengths=audio_lens,
        )
        try:
            # for discriminator
            with autocast(enabled=params.use_fp16):
                loss_d, stats_d = model(**model_inputs, forward_generator=False)
            optimizer_d.zero_grad()
            loss_d.backward()
            # for generator
            with autocast(enabled=params.use_fp16):
                loss_g, stats_g = model(**model_inputs, forward_generator=True)
            optimizer_g.zero_grad()
            loss_g.backward()
        except Exception as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current "
                    "max_duration setting. We recommend decreasing "
                    "max_duration and trying again.\n"
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
            raise
        logging.info(
            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
        )
|
|
||||||
|
|
||||||
|
|
||||||
def run(rank, world_size, args):
    """Train the model in one process (one DDP rank).

    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        Only the process with rank 0 creates a tensorboard writer and
        copies the best-train-loss / best-valid-loss checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    tb_writer = (
        SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
        if args.tensorboard and rank == 0
        else None
    )

    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    else:
        device = torch.device("cpu")
    logging.info(f"Device: {device}")

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    # Keep direct handles on the two sub-networks; each gets its own optimizer.
    generator = model.generator
    discriminator = model.discriminator

    num_param_g = sum(p.numel() for p in generator.parameters())
    logging.info(f"Number of parameters in generator: {num_param_g}")
    num_param_d = sum(p.numel() for p in discriminator.parameters())
    logging.info(f"Number of parameters in discriminator: {num_param_d}")
    logging.info(f"Total number of parameters: {num_param_g + num_param_d}")

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(params=params, model=model)

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer_g = torch.optim.AdamW(
        generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
    )
    optimizer_d = torch.optim.AdamW(
        discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
    )

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)

    if checkpoints is not None:
        # Restore optimizer states first, then scheduler states, for every
        # entry that is present in the checkpoint.
        resumable = (
            ("optimizer_g", optimizer_g),
            ("optimizer_d", optimizer_d),
            ("scheduler_g", scheduler_g),
            ("scheduler_d", scheduler_d),
        )
        for name, obj in resumable:
            if name in checkpoints:
                logging.info(f"Loading {name} state dict")
                obj.load_state_dict(checkpoints[name])

    if params.print_diagnostics:
        # NOTE(review): the original comment here said "allow 4 megabytes
        # per sub-module", which does not obviously match the value 512;
        # confirm the unit against TensorDiagnosticOptions.
        opts = diagnostics.TensorDiagnosticOptions(512)
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    baker_zh = BakerZhSpeechTtsDataModule(args)

    train_cuts = baker_zh.train_cuts()

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 20 seconds
        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold
        return not (c.duration < 1.0 or c.duration > 20.0)

    train_cuts = train_cuts.filter(remove_short_and_long_utt)
    train_dl = baker_zh.train_dataloaders(train_cuts)

    valid_cuts = baker_zh.valid_cuts()
    valid_dl = baker_zh.valid_dataloaders(valid_cuts)

    if not params.print_diagnostics:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            tokenizer=tokenizer,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            params=params,
        )

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        logging.info(f"Start epoch {epoch}")

        # Re-seed per epoch so a resumed run sees the same randomness.
        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        params.cur_epoch = epoch

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        train_one_epoch(
            params=params,
            model=model,
            tokenizer=tokenizer,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,
            scheduler_d=scheduler_d,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        is_save_epoch = (
            epoch % params.save_every_n == 0 or epoch == params.num_epochs
        )
        if is_save_epoch:
            filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
            save_checkpoint(
                filename=filename,
                params=params,
                model=model,
                optimizer_g=optimizer_g,
                optimizer_d=optimizer_d,
                scheduler_g=scheduler_g,
                scheduler_d=scheduler_d,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            if rank == 0:
                if params.best_train_epoch == params.cur_epoch:
                    best_train_filename = params.exp_dir / "best-train-loss.pt"
                    copyfile(src=filename, dst=best_train_filename)

                if params.best_valid_epoch == params.cur_epoch:
                    best_valid_filename = params.exp_dir / "best-valid-loss.pt"
                    copyfile(src=filename, dst=best_valid_filename)

        # step per epoch
        scheduler_g.step()
        scheduler_d.step()

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Parse command-line options and start training.

    The data-module options are registered on top of the parser returned by
    ``get_parser()``.  ``args.exp_dir`` is converted to a ``Path``.  With a
    world size above one, ``run`` is spawned once per process via
    ``mp.spawn``; otherwise ``run`` is called directly in this process as
    rank 0.
    """
    arg_parser = get_parser()
    BakerZhSpeechTtsDataModule.add_arguments(arg_parser)

    args = arg_parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    num_procs = args.world_size
    assert num_procs >= 1

    # Single-process case: no need to spawn anything.
    if not num_procs > 1:
        run(rank=0, world_size=1, args=args)
        return

    mp.spawn(run, args=(num_procs, args), nprocs=num_procs, join=True)
|
|
||||||
|
|
||||||
|
|
||||||
# Restrict PyTorch to a single intra-op and a single inter-op thread.
# These calls run at import time, before main() is invoked.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


if __name__ == "__main__":
    # Only start training when executed as a script, not when imported.
    main()
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/transform.py
|
|
@ -1,330 +0,0 @@
|
|||||||
# Copyright 2021 Piotr Żelasko
|
|
||||||
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
|
|
||||||
# Zengwei Yao)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
from functools import lru_cache
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
|
|
||||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
|
||||||
CutConcatenate,
|
|
||||||
CutMix,
|
|
||||||
DynamicBucketingSampler,
|
|
||||||
PrecomputedFeatures,
|
|
||||||
SimpleCutSampler,
|
|
||||||
SpecAugment,
|
|
||||||
SpeechSynthesisDataset,
|
|
||||||
)
|
|
||||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
|
||||||
AudioSamples,
|
|
||||||
OnTheFlyFeatures,
|
|
||||||
)
|
|
||||||
from lhotse.utils import fix_random_seed
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
class _SeedWorkers:
|
|
||||||
def __init__(self, seed: int):
|
|
||||||
self.seed = seed
|
|
||||||
|
|
||||||
def __call__(self, worker_id: int):
|
|
||||||
fix_random_seed(self.seed + worker_id)
|
|
||||||
|
|
||||||
|
|
||||||
class BakerZhSpeechTtsDataModule:
    """
    DataModule for the baker_zh TTS experiments.

    It builds the train, validation and test ``DataLoader`` objects from
    Lhotse ``CutSet`` manifests.  Every dataset is a
    ``SpeechSynthesisDataset`` that returns tokens (not text).  Features
    come either from the strategy named by ``--input-strategy``
    (``PrecomputedFeatures`` or ``AudioSamples``) or, when
    ``--on-the-fly-feats`` is enabled, from spectrograms computed on the fly.

    Batches are formed by a ``DynamicBucketingSampler`` (or, for training
    only, a ``SimpleCutSampler`` when ``--bucketing-sampler`` is disabled),
    so the dataloaders are created with ``batch_size=None``.
    """

    def __init__(self, args: argparse.Namespace):
        """
        Args:
          args:
            Parsed command-line options; see :meth:`add_arguments` for the
            options read by this class.
        """
        self.args = args
        # Sampling rate used to derive the on-the-fly spectrogram
        # frame length and frame shift.
        self.sampling_rate = 48000

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Register the data-related command-line options on ``parser``."""
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/spectrogram"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            # The default is a float number of seconds, so parse
            # user-supplied values as float too (the original used int,
            # which rejected fractional durations).
            type=float,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )

        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=False,
            help="When enabled, each batch will have the "
            "field: batch['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

    def _spectrogram_config(self) -> SpectrogramConfig:
        """Return the config used for on-the-fly spectrogram extraction."""
        sampling_rate = self.sampling_rate
        return SpectrogramConfig(
            sampling_rate=sampling_rate,
            frame_length=1024 / sampling_rate,  # (in second),
            frame_shift=256 / sampling_rate,  # (in second)
            use_fft_mag=True,
        )

    def _feature_input_strategy(self):
        """Return the feature input strategy selected by the options.

        ``--on-the-fly-feats`` takes precedence over ``--input-strategy``.

        Raises:
          ValueError: if ``--input-strategy`` names an unknown strategy.
        """
        if self.args.on_the_fly_feats:
            return OnTheFlyFeatures(Spectrogram(self._spectrogram_config()))

        # Explicit lookup instead of eval(): the option value comes from the
        # command line and must not be executed as arbitrary Python code.
        strategies = {
            "AudioSamples": AudioSamples,
            "PrecomputedFeatures": PrecomputedFeatures,
        }
        name = self.args.input_strategy
        try:
            strategy_cls = strategies[name]
        except KeyError:
            raise ValueError(
                f"Unsupported --input-strategy: {name!r}. "
                f"Expected one of: {', '.join(sorted(strategies))}"
            ) from None
        return strategy_cls()

    def _make_dataset(self) -> SpeechSynthesisDataset:
        """Create the dataset shared by the train/valid/test dataloaders."""
        return SpeechSynthesisDataset(
            return_text=False,
            return_tokens=True,
            feature_input_strategy=self._feature_input_strategy(),
            return_cuts=self.args.return_cuts,
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        Returns:
          The training dataloader.
        """
        logging.info("About to create train dataset")
        # Built exactly once; the original constructed a dataset and then
        # discarded it whenever --on-the-fly-feats was enabled.
        train = self._make_dataset()

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=self.args.num_buckets * 2000,
                shuffle_buffer_size=self.args.num_buckets * 5000,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,  # the sampler already yields whole batches
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        """Create the validation dataloader for ``cuts_valid`` (unshuffled)."""
        logging.info("About to create dev dataset")
        validate = self._make_dataset()
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            num_buckets=self.args.num_buckets,
            shuffle=False,
        )
        logging.info("About to create valid dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        """Create the test dataloader for ``cuts`` (unshuffled)."""
        logging.info("About to create test dataset")
        test = self._make_dataset()
        test_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            num_buckets=self.args.num_buckets,
            shuffle=False,
        )
        logging.info("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=test_sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    # NOTE(review): lru_cache on a method keys on ``self`` and keeps the
    # instance alive for the lifetime of the cache; kept as-is to preserve
    # the existing memoization behavior.
    @lru_cache()
    def train_cuts(self) -> CutSet:
        """Lazily load the training cuts from ``--manifest-dir``."""
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz"
        )

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        """Lazily load the validation cuts from ``--manifest-dir``."""
        logging.info("About to get validation cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz"
        )

    @lru_cache()
    def test_cuts(self) -> CutSet:
        """Lazily load the test cuts from ``--manifest-dir``."""
        logging.info("About to get test cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz"
        )
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/utils.py
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/vits.py
|
|
@ -1 +0,0 @@
|
|||||||
../../../ljspeech/TTS/vits/wavenet.py
|
|
Loading…
x
Reference in New Issue
Block a user