Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Commit 8b867affee ("first working version"), parent 5723ce85c8
3  .gitignore (vendored)

@@ -36,3 +36,6 @@ node_modules
 .DS_Store
 *.fst
 *.arpa
+core.c
+*.so
+build
(documentation hunk; the file name is not shown in this view)

@@ -19,7 +19,7 @@ Install extra dependencies
 .. code-block:: bash

   pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
-  pip install numba espnet_tts_frontend
+  pip install numba espnet_tts_frontend cython

 Data preparation
 ----------------
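The only substantive change in this hunk is the addition of `cython` to the dependency list, presumably because Stage 0 of the preparation script later in this commit builds the `vits/monotonic_align` extension with `setup.py build_ext`. A rough, hedged sanity check (not part of the commit) that the extra packages are importable; `espnet_tts_frontend` is left out here because I am not certain its import name matches the pip package name:

```python
# Hedged sanity check for the extra TTS dependencies installed above.
import importlib

for name in ["piper_phonemize", "numba", "Cython"]:
    importlib.import_module(name)  # raises ImportError if the package is missing
    print("ok:", name)
```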
7  egs/baker_zh/TTS/local/README.md (new file)

@@ -0,0 +1,7 @@
# Introduction

[./symbols.py](./symbols.py) is copied from
https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py

[./pypinyin-local.dict](./pypinyin-local.dict) is copied from
https://github.com/UEhQZXI/vits_chinese/blob/master/misc/pypinyin-local.dict
0  egs/baker_zh/TTS/local/__init__.py (new, empty file)
106  egs/baker_zh/TTS/local/compute_spectrogram_baker.py (new executable file)

@@ -0,0 +1,106 @@
#!/usr/bin/env python3
# Copyright    2021-2023  Xiaomi Corp.        (authors: Fangjun Kuang,
#                                                       Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes the spectrogram features of the baker_zh dataset.

It looks for manifests in the directory data/manifests.

The generated spectrogram features are saved in data/spectrogram.
"""

import logging
import os
from pathlib import Path

import torch
from lhotse import (
    CutSet,
    LilcomChunkyWriter,
    Spectrogram,
    SpectrogramConfig,
    load_manifest,
)
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_spectrogram_baker_zh():
    src_dir = Path("data/manifests")
    output_dir = Path("data/spectrogram")
    num_jobs = min(4, os.cpu_count())

    sampling_rate = 48000
    frame_length = 1024 / sampling_rate  # (in seconds)
    frame_shift = 256 / sampling_rate  # (in seconds)
    use_fft_mag = True

    prefix = "baker_zh"
    suffix = "jsonl.gz"
    partition = "all"

    recordings = load_manifest(
        src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
    )
    supervisions = load_manifest(
        src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
    )

    config = SpectrogramConfig(
        sampling_rate=sampling_rate,
        frame_length=frame_length,
        frame_shift=frame_shift,
        use_fft_mag=use_fft_mag,
    )
    extractor = Spectrogram(config)

    with get_executor() as ex:  # Initialize the executor only once.
        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
        if (output_dir / cuts_filename).is_file():
            logging.info(f"{cuts_filename} already exists - skipping.")
            return
        logging.info(f"Processing {partition}")
        cut_set = CutSet.from_manifests(
            recordings=recordings, supervisions=supervisions
        )

        cut_set = cut_set.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{output_dir}/{prefix}_feats_{partition}",
            # when an executor is specified, make more partitions
            num_jobs=num_jobs if ex is None else 80,
            executor=ex,
            storage_type=LilcomChunkyWriter,
        )
        cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    compute_spectrogram_baker_zh()
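A minimal sketch (not part of the commit) of how the manifest written by this script can be inspected afterwards, assuming lhotse is installed and the script has been run from egs/baker_zh/TTS:

```python
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/spectrogram/baker_zh_cuts_all.jsonl.gz")
first = next(iter(cuts))
feats = first.load_features()  # numpy array of shape (num_frames, num_bins)
print(first.id, first.duration, feats.shape)
```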
421  egs/baker_zh/TTS/local/pinyin_dict.py (new file)

@@ -0,0 +1,421 @@
|
# This dict is copied from
|
||||||
|
# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py
|
||||||
|
pinyin_dict = {
|
||||||
|
"a": ("^", "a"),
|
||||||
|
"ai": ("^", "ai"),
|
||||||
|
"an": ("^", "an"),
|
||||||
|
"ang": ("^", "ang"),
|
||||||
|
"ao": ("^", "ao"),
|
||||||
|
"ba": ("b", "a"),
|
||||||
|
"bai": ("b", "ai"),
|
||||||
|
"ban": ("b", "an"),
|
||||||
|
"bang": ("b", "ang"),
|
||||||
|
"bao": ("b", "ao"),
|
||||||
|
"be": ("b", "e"),
|
||||||
|
"bei": ("b", "ei"),
|
||||||
|
"ben": ("b", "en"),
|
||||||
|
"beng": ("b", "eng"),
|
||||||
|
"bi": ("b", "i"),
|
||||||
|
"bian": ("b", "ian"),
|
||||||
|
"biao": ("b", "iao"),
|
||||||
|
"bie": ("b", "ie"),
|
||||||
|
"bin": ("b", "in"),
|
||||||
|
"bing": ("b", "ing"),
|
||||||
|
"bo": ("b", "o"),
|
||||||
|
"bu": ("b", "u"),
|
||||||
|
"ca": ("c", "a"),
|
||||||
|
"cai": ("c", "ai"),
|
||||||
|
"can": ("c", "an"),
|
||||||
|
"cang": ("c", "ang"),
|
||||||
|
"cao": ("c", "ao"),
|
||||||
|
"ce": ("c", "e"),
|
||||||
|
"cen": ("c", "en"),
|
||||||
|
"ceng": ("c", "eng"),
|
||||||
|
"cha": ("ch", "a"),
|
||||||
|
"chai": ("ch", "ai"),
|
||||||
|
"chan": ("ch", "an"),
|
||||||
|
"chang": ("ch", "ang"),
|
||||||
|
"chao": ("ch", "ao"),
|
||||||
|
"che": ("ch", "e"),
|
||||||
|
"chen": ("ch", "en"),
|
||||||
|
"cheng": ("ch", "eng"),
|
||||||
|
"chi": ("ch", "iii"),
|
||||||
|
"chong": ("ch", "ong"),
|
||||||
|
"chou": ("ch", "ou"),
|
||||||
|
"chu": ("ch", "u"),
|
||||||
|
"chua": ("ch", "ua"),
|
||||||
|
"chuai": ("ch", "uai"),
|
||||||
|
"chuan": ("ch", "uan"),
|
||||||
|
"chuang": ("ch", "uang"),
|
||||||
|
"chui": ("ch", "uei"),
|
||||||
|
"chun": ("ch", "uen"),
|
||||||
|
"chuo": ("ch", "uo"),
|
||||||
|
"ci": ("c", "ii"),
|
||||||
|
"cong": ("c", "ong"),
|
||||||
|
"cou": ("c", "ou"),
|
||||||
|
"cu": ("c", "u"),
|
||||||
|
"cuan": ("c", "uan"),
|
||||||
|
"cui": ("c", "uei"),
|
||||||
|
"cun": ("c", "uen"),
|
||||||
|
"cuo": ("c", "uo"),
|
||||||
|
"da": ("d", "a"),
|
||||||
|
"dai": ("d", "ai"),
|
||||||
|
"dan": ("d", "an"),
|
||||||
|
"dang": ("d", "ang"),
|
||||||
|
"dao": ("d", "ao"),
|
||||||
|
"de": ("d", "e"),
|
||||||
|
"dei": ("d", "ei"),
|
||||||
|
"den": ("d", "en"),
|
||||||
|
"deng": ("d", "eng"),
|
||||||
|
"di": ("d", "i"),
|
||||||
|
"dia": ("d", "ia"),
|
||||||
|
"dian": ("d", "ian"),
|
||||||
|
"diao": ("d", "iao"),
|
||||||
|
"die": ("d", "ie"),
|
||||||
|
"ding": ("d", "ing"),
|
||||||
|
"diu": ("d", "iou"),
|
||||||
|
"dong": ("d", "ong"),
|
||||||
|
"dou": ("d", "ou"),
|
||||||
|
"du": ("d", "u"),
|
||||||
|
"duan": ("d", "uan"),
|
||||||
|
"dui": ("d", "uei"),
|
||||||
|
"dun": ("d", "uen"),
|
||||||
|
"duo": ("d", "uo"),
|
||||||
|
"e": ("^", "e"),
|
||||||
|
"ei": ("^", "ei"),
|
||||||
|
"en": ("^", "en"),
|
||||||
|
"ng": ("^", "en"),
|
||||||
|
"eng": ("^", "eng"),
|
||||||
|
"er": ("^", "er"),
|
||||||
|
"fa": ("f", "a"),
|
||||||
|
"fan": ("f", "an"),
|
||||||
|
"fang": ("f", "ang"),
|
||||||
|
"fei": ("f", "ei"),
|
||||||
|
"fen": ("f", "en"),
|
||||||
|
"feng": ("f", "eng"),
|
||||||
|
"fo": ("f", "o"),
|
||||||
|
"fou": ("f", "ou"),
|
||||||
|
"fu": ("f", "u"),
|
||||||
|
"ga": ("g", "a"),
|
||||||
|
"gai": ("g", "ai"),
|
||||||
|
"gan": ("g", "an"),
|
||||||
|
"gang": ("g", "ang"),
|
||||||
|
"gao": ("g", "ao"),
|
||||||
|
"ge": ("g", "e"),
|
||||||
|
"gei": ("g", "ei"),
|
||||||
|
"gen": ("g", "en"),
|
||||||
|
"geng": ("g", "eng"),
|
||||||
|
"gong": ("g", "ong"),
|
||||||
|
"gou": ("g", "ou"),
|
||||||
|
"gu": ("g", "u"),
|
||||||
|
"gua": ("g", "ua"),
|
||||||
|
"guai": ("g", "uai"),
|
||||||
|
"guan": ("g", "uan"),
|
||||||
|
"guang": ("g", "uang"),
|
||||||
|
"gui": ("g", "uei"),
|
||||||
|
"gun": ("g", "uen"),
|
||||||
|
"guo": ("g", "uo"),
|
||||||
|
"ha": ("h", "a"),
|
||||||
|
"hai": ("h", "ai"),
|
||||||
|
"han": ("h", "an"),
|
||||||
|
"hang": ("h", "ang"),
|
||||||
|
"hao": ("h", "ao"),
|
||||||
|
"he": ("h", "e"),
|
||||||
|
"hei": ("h", "ei"),
|
||||||
|
"hen": ("h", "en"),
|
||||||
|
"heng": ("h", "eng"),
|
||||||
|
"hong": ("h", "ong"),
|
||||||
|
"hou": ("h", "ou"),
|
||||||
|
"hu": ("h", "u"),
|
||||||
|
"hua": ("h", "ua"),
|
||||||
|
"huai": ("h", "uai"),
|
||||||
|
"huan": ("h", "uan"),
|
||||||
|
"huang": ("h", "uang"),
|
||||||
|
"hui": ("h", "uei"),
|
||||||
|
"hun": ("h", "uen"),
|
||||||
|
"huo": ("h", "uo"),
|
||||||
|
"ji": ("j", "i"),
|
||||||
|
"jia": ("j", "ia"),
|
||||||
|
"jian": ("j", "ian"),
|
||||||
|
"jiang": ("j", "iang"),
|
||||||
|
"jiao": ("j", "iao"),
|
||||||
|
"jie": ("j", "ie"),
|
||||||
|
"jin": ("j", "in"),
|
||||||
|
"jing": ("j", "ing"),
|
||||||
|
"jiong": ("j", "iong"),
|
||||||
|
"jiu": ("j", "iou"),
|
||||||
|
"ju": ("j", "v"),
|
||||||
|
"juan": ("j", "van"),
|
||||||
|
"jue": ("j", "ve"),
|
||||||
|
"jun": ("j", "vn"),
|
||||||
|
"ka": ("k", "a"),
|
||||||
|
"kai": ("k", "ai"),
|
||||||
|
"kan": ("k", "an"),
|
||||||
|
"kang": ("k", "ang"),
|
||||||
|
"kao": ("k", "ao"),
|
||||||
|
"ke": ("k", "e"),
|
||||||
|
"kei": ("k", "ei"),
|
||||||
|
"ken": ("k", "en"),
|
||||||
|
"keng": ("k", "eng"),
|
||||||
|
"kong": ("k", "ong"),
|
||||||
|
"kou": ("k", "ou"),
|
||||||
|
"ku": ("k", "u"),
|
||||||
|
"kua": ("k", "ua"),
|
||||||
|
"kuai": ("k", "uai"),
|
||||||
|
"kuan": ("k", "uan"),
|
||||||
|
"kuang": ("k", "uang"),
|
||||||
|
"kui": ("k", "uei"),
|
||||||
|
"kun": ("k", "uen"),
|
||||||
|
"kuo": ("k", "uo"),
|
||||||
|
"la": ("l", "a"),
|
||||||
|
"lai": ("l", "ai"),
|
||||||
|
"lan": ("l", "an"),
|
||||||
|
"lang": ("l", "ang"),
|
||||||
|
"lao": ("l", "ao"),
|
||||||
|
"le": ("l", "e"),
|
||||||
|
"lei": ("l", "ei"),
|
||||||
|
"leng": ("l", "eng"),
|
||||||
|
"li": ("l", "i"),
|
||||||
|
"lia": ("l", "ia"),
|
||||||
|
"lian": ("l", "ian"),
|
||||||
|
"liang": ("l", "iang"),
|
||||||
|
"liao": ("l", "iao"),
|
||||||
|
"lie": ("l", "ie"),
|
||||||
|
"lin": ("l", "in"),
|
||||||
|
"ling": ("l", "ing"),
|
||||||
|
"liu": ("l", "iou"),
|
||||||
|
"lo": ("l", "o"),
|
||||||
|
"long": ("l", "ong"),
|
||||||
|
"lou": ("l", "ou"),
|
||||||
|
"lu": ("l", "u"),
|
||||||
|
"lv": ("l", "v"),
|
||||||
|
"luan": ("l", "uan"),
|
||||||
|
"lve": ("l", "ve"),
|
||||||
|
"lue": ("l", "ve"),
|
||||||
|
"lun": ("l", "uen"),
|
||||||
|
"luo": ("l", "uo"),
|
||||||
|
"ma": ("m", "a"),
|
||||||
|
"mai": ("m", "ai"),
|
||||||
|
"man": ("m", "an"),
|
||||||
|
"mang": ("m", "ang"),
|
||||||
|
"mao": ("m", "ao"),
|
||||||
|
"me": ("m", "e"),
|
||||||
|
"mei": ("m", "ei"),
|
||||||
|
"men": ("m", "en"),
|
||||||
|
"meng": ("m", "eng"),
|
||||||
|
"mi": ("m", "i"),
|
||||||
|
"mian": ("m", "ian"),
|
||||||
|
"miao": ("m", "iao"),
|
||||||
|
"mie": ("m", "ie"),
|
||||||
|
"min": ("m", "in"),
|
||||||
|
"ming": ("m", "ing"),
|
||||||
|
"miu": ("m", "iou"),
|
||||||
|
"mo": ("m", "o"),
|
||||||
|
"mou": ("m", "ou"),
|
||||||
|
"mu": ("m", "u"),
|
||||||
|
"na": ("n", "a"),
|
||||||
|
"nai": ("n", "ai"),
|
||||||
|
"nan": ("n", "an"),
|
||||||
|
"nang": ("n", "ang"),
|
||||||
|
"nao": ("n", "ao"),
|
||||||
|
"ne": ("n", "e"),
|
||||||
|
"nei": ("n", "ei"),
|
||||||
|
"nen": ("n", "en"),
|
||||||
|
"neng": ("n", "eng"),
|
||||||
|
"ni": ("n", "i"),
|
||||||
|
"nia": ("n", "ia"),
|
||||||
|
"nian": ("n", "ian"),
|
||||||
|
"niang": ("n", "iang"),
|
||||||
|
"niao": ("n", "iao"),
|
||||||
|
"nie": ("n", "ie"),
|
||||||
|
"nin": ("n", "in"),
|
||||||
|
"ning": ("n", "ing"),
|
||||||
|
"niu": ("n", "iou"),
|
||||||
|
"nong": ("n", "ong"),
|
||||||
|
"nou": ("n", "ou"),
|
||||||
|
"nu": ("n", "u"),
|
||||||
|
"nv": ("n", "v"),
|
||||||
|
"nuan": ("n", "uan"),
|
||||||
|
"nve": ("n", "ve"),
|
||||||
|
"nue": ("n", "ve"),
|
||||||
|
"nuo": ("n", "uo"),
|
||||||
|
"o": ("^", "o"),
|
||||||
|
"ou": ("^", "ou"),
|
||||||
|
"pa": ("p", "a"),
|
||||||
|
"pai": ("p", "ai"),
|
||||||
|
"pan": ("p", "an"),
|
||||||
|
"pang": ("p", "ang"),
|
||||||
|
"pao": ("p", "ao"),
|
||||||
|
"pe": ("p", "e"),
|
||||||
|
"pei": ("p", "ei"),
|
||||||
|
"pen": ("p", "en"),
|
||||||
|
"peng": ("p", "eng"),
|
||||||
|
"pi": ("p", "i"),
|
||||||
|
"pian": ("p", "ian"),
|
||||||
|
"piao": ("p", "iao"),
|
||||||
|
"pie": ("p", "ie"),
|
||||||
|
"pin": ("p", "in"),
|
||||||
|
"ping": ("p", "ing"),
|
||||||
|
"po": ("p", "o"),
|
||||||
|
"pou": ("p", "ou"),
|
||||||
|
"pu": ("p", "u"),
|
||||||
|
"qi": ("q", "i"),
|
||||||
|
"qia": ("q", "ia"),
|
||||||
|
"qian": ("q", "ian"),
|
||||||
|
"qiang": ("q", "iang"),
|
||||||
|
"qiao": ("q", "iao"),
|
||||||
|
"qie": ("q", "ie"),
|
||||||
|
"qin": ("q", "in"),
|
||||||
|
"qing": ("q", "ing"),
|
||||||
|
"qiong": ("q", "iong"),
|
||||||
|
"qiu": ("q", "iou"),
|
||||||
|
"qu": ("q", "v"),
|
||||||
|
"quan": ("q", "van"),
|
||||||
|
"que": ("q", "ve"),
|
||||||
|
"qun": ("q", "vn"),
|
||||||
|
"ran": ("r", "an"),
|
||||||
|
"rang": ("r", "ang"),
|
||||||
|
"rao": ("r", "ao"),
|
||||||
|
"re": ("r", "e"),
|
||||||
|
"ren": ("r", "en"),
|
||||||
|
"reng": ("r", "eng"),
|
||||||
|
"ri": ("r", "iii"),
|
||||||
|
"rong": ("r", "ong"),
|
||||||
|
"rou": ("r", "ou"),
|
||||||
|
"ru": ("r", "u"),
|
||||||
|
"rua": ("r", "ua"),
|
||||||
|
"ruan": ("r", "uan"),
|
||||||
|
"rui": ("r", "uei"),
|
||||||
|
"run": ("r", "uen"),
|
||||||
|
"ruo": ("r", "uo"),
|
||||||
|
"sa": ("s", "a"),
|
||||||
|
"sai": ("s", "ai"),
|
||||||
|
"san": ("s", "an"),
|
||||||
|
"sang": ("s", "ang"),
|
||||||
|
"sao": ("s", "ao"),
|
||||||
|
"se": ("s", "e"),
|
||||||
|
"sen": ("s", "en"),
|
||||||
|
"seng": ("s", "eng"),
|
||||||
|
"sha": ("sh", "a"),
|
||||||
|
"shai": ("sh", "ai"),
|
||||||
|
"shan": ("sh", "an"),
|
||||||
|
"shang": ("sh", "ang"),
|
||||||
|
"shao": ("sh", "ao"),
|
||||||
|
"she": ("sh", "e"),
|
||||||
|
"shei": ("sh", "ei"),
|
||||||
|
"shen": ("sh", "en"),
|
||||||
|
"sheng": ("sh", "eng"),
|
||||||
|
"shi": ("sh", "iii"),
|
||||||
|
"shou": ("sh", "ou"),
|
||||||
|
"shu": ("sh", "u"),
|
||||||
|
"shua": ("sh", "ua"),
|
||||||
|
"shuai": ("sh", "uai"),
|
||||||
|
"shuan": ("sh", "uan"),
|
||||||
|
"shuang": ("sh", "uang"),
|
||||||
|
"shui": ("sh", "uei"),
|
||||||
|
"shun": ("sh", "uen"),
|
||||||
|
"shuo": ("sh", "uo"),
|
||||||
|
"si": ("s", "ii"),
|
||||||
|
"song": ("s", "ong"),
|
||||||
|
"sou": ("s", "ou"),
|
||||||
|
"su": ("s", "u"),
|
||||||
|
"suan": ("s", "uan"),
|
||||||
|
"sui": ("s", "uei"),
|
||||||
|
"sun": ("s", "uen"),
|
||||||
|
"suo": ("s", "uo"),
|
||||||
|
"ta": ("t", "a"),
|
||||||
|
"tai": ("t", "ai"),
|
||||||
|
"tan": ("t", "an"),
|
||||||
|
"tang": ("t", "ang"),
|
||||||
|
"tao": ("t", "ao"),
|
||||||
|
"te": ("t", "e"),
|
||||||
|
"tei": ("t", "ei"),
|
||||||
|
"teng": ("t", "eng"),
|
||||||
|
"ti": ("t", "i"),
|
||||||
|
"tian": ("t", "ian"),
|
||||||
|
"tiao": ("t", "iao"),
|
||||||
|
"tie": ("t", "ie"),
|
||||||
|
"ting": ("t", "ing"),
|
||||||
|
"tong": ("t", "ong"),
|
||||||
|
"tou": ("t", "ou"),
|
||||||
|
"tu": ("t", "u"),
|
||||||
|
"tuan": ("t", "uan"),
|
||||||
|
"tui": ("t", "uei"),
|
||||||
|
"tun": ("t", "uen"),
|
||||||
|
"tuo": ("t", "uo"),
|
||||||
|
"wa": ("^", "ua"),
|
||||||
|
"wai": ("^", "uai"),
|
||||||
|
"wan": ("^", "uan"),
|
||||||
|
"wang": ("^", "uang"),
|
||||||
|
"wei": ("^", "uei"),
|
||||||
|
"wen": ("^", "uen"),
|
||||||
|
"weng": ("^", "ueng"),
|
||||||
|
"wo": ("^", "uo"),
|
||||||
|
"wu": ("^", "u"),
|
||||||
|
"xi": ("x", "i"),
|
||||||
|
"xia": ("x", "ia"),
|
||||||
|
"xian": ("x", "ian"),
|
||||||
|
"xiang": ("x", "iang"),
|
||||||
|
"xiao": ("x", "iao"),
|
||||||
|
"xie": ("x", "ie"),
|
||||||
|
"xin": ("x", "in"),
|
||||||
|
"xing": ("x", "ing"),
|
||||||
|
"xiong": ("x", "iong"),
|
||||||
|
"xiu": ("x", "iou"),
|
||||||
|
"xu": ("x", "v"),
|
||||||
|
"xuan": ("x", "van"),
|
||||||
|
"xue": ("x", "ve"),
|
||||||
|
"xun": ("x", "vn"),
|
||||||
|
"ya": ("^", "ia"),
|
||||||
|
"yan": ("^", "ian"),
|
||||||
|
"yang": ("^", "iang"),
|
||||||
|
"yao": ("^", "iao"),
|
||||||
|
"ye": ("^", "ie"),
|
||||||
|
"yi": ("^", "i"),
|
||||||
|
"yin": ("^", "in"),
|
||||||
|
"ying": ("^", "ing"),
|
||||||
|
"yo": ("^", "iou"),
|
||||||
|
"yong": ("^", "iong"),
|
||||||
|
"you": ("^", "iou"),
|
||||||
|
"yu": ("^", "v"),
|
||||||
|
"yuan": ("^", "van"),
|
||||||
|
"yue": ("^", "ve"),
|
||||||
|
"yun": ("^", "vn"),
|
||||||
|
"za": ("z", "a"),
|
||||||
|
"zai": ("z", "ai"),
|
||||||
|
"zan": ("z", "an"),
|
||||||
|
"zang": ("z", "ang"),
|
||||||
|
"zao": ("z", "ao"),
|
||||||
|
"ze": ("z", "e"),
|
||||||
|
"zei": ("z", "ei"),
|
||||||
|
"zen": ("z", "en"),
|
||||||
|
"zeng": ("z", "eng"),
|
||||||
|
"zha": ("zh", "a"),
|
||||||
|
"zhai": ("zh", "ai"),
|
||||||
|
"zhan": ("zh", "an"),
|
||||||
|
"zhang": ("zh", "ang"),
|
||||||
|
"zhao": ("zh", "ao"),
|
||||||
|
"zhe": ("zh", "e"),
|
||||||
|
"zhei": ("zh", "ei"),
|
||||||
|
"zhen": ("zh", "en"),
|
||||||
|
"zheng": ("zh", "eng"),
|
||||||
|
"zhi": ("zh", "iii"),
|
||||||
|
"zhong": ("zh", "ong"),
|
||||||
|
"zhou": ("zh", "ou"),
|
||||||
|
"zhu": ("zh", "u"),
|
||||||
|
"zhua": ("zh", "ua"),
|
||||||
|
"zhuai": ("zh", "uai"),
|
||||||
|
"zhuan": ("zh", "uan"),
|
||||||
|
"zhuang": ("zh", "uang"),
|
||||||
|
"zhui": ("zh", "uei"),
|
||||||
|
"zhun": ("zh", "uen"),
|
||||||
|
"zhuo": ("zh", "uo"),
|
||||||
|
"zi": ("z", "ii"),
|
||||||
|
"zong": ("z", "ong"),
|
||||||
|
"zou": ("z", "ou"),
|
||||||
|
"zu": ("z", "u"),
|
||||||
|
"zuan": ("z", "uan"),
|
||||||
|
"zui": ("z", "uei"),
|
||||||
|
"zun": ("z", "uen"),
|
||||||
|
"zuo": ("z", "uo"),
|
||||||
|
}
|
||||||
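The dictionary above maps a toneless pinyin syllable to an (initial, final) pair; the tokenizer later in this commit appends the tone digit to the final and adds a word-boundary symbol `#0`. A small illustrative sketch (not part of the commit; the helper name is hypothetical):

```python
from pinyin_dict import pinyin_dict


def split_syllable(syllable: str):
    """Split a TONE3-style syllable such as 'zhong1' into phoneme symbols."""
    tone = syllable[-1]   # "1".."5"
    base = syllable[:-1]  # e.g. "zhong"
    initial, final = pinyin_dict[base]
    return [initial, final + tone, "#0"]


print(split_syllable("zhong1"))  # ['zh', 'ong1', '#0']
print(split_syllable("yu3"))     # ['^', 'v3', '#0']
```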
53  egs/baker_zh/TTS/local/prepare_token_file.py (new executable file)

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file generates the file that maps tokens to IDs.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict

from symbols import symbols


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--tokens",
        type=Path,
        default=Path("data/tokens.txt"),
        help="Path to the dict that maps the text tokens to IDs",
    )

    return parser.parse_args()


def main():
    args = get_args()
    tokens = Path(args.tokens)

    with open(tokens, "w", encoding="utf-8") as f:
        for token_id, token in enumerate(symbols):
            f.write(f"{token} {token_id}\n")


if __name__ == "__main__":
    main()
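The script writes one `<token> <id>` pair per line. A small sketch (not part of the commit) that reads the generated data/tokens.txt back into a dict:

```python
token2id = {}
with open("data/tokens.txt", encoding="utf-8") as f:
    for line in f:
        token, idx = line.split()
        token2id[token] = int(idx)

print(len(token2id), token2id["sil"], token2id["#0"])
```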
59  egs/baker_zh/TTS/local/prepare_tokens_baker_zh.py (new executable file)

@@ -0,0 +1,59 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file reads the texts in the given manifest and saves new cuts with the
tokens attached.
"""

import logging
from pathlib import Path

from lhotse import CutSet, load_manifest

from tokenizer import Tokenizer


def prepare_tokens_baker_zh():
    output_dir = Path("data/spectrogram")
    prefix = "baker_zh"
    suffix = "jsonl.gz"
    partition = "all"

    cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")

    tokenizer = Tokenizer()

    new_cuts = []
    for cut in cut_set:
        # Each cut only contains one supervision
        assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
        text = cut.supervisions[0].normalized_text
        cut.tokens = tokenizer.text_to_tokens(text)

        new_cuts.append(cut)

    new_cut_set = CutSet.from_cuts(new_cuts)
    new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    prepare_tokens_baker_zh()
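A sketch (not part of the commit) that peeks at the tokens attached to each cut, run right after this script (the preparation script later renames the file back to baker_zh_cuts_all.jsonl.gz); it assumes lhotse round-trips the custom `tokens` field:

```python
from lhotse import load_manifest

cuts = load_manifest("data/spectrogram/baker_zh_cuts_with_tokens_all.jsonl.gz")
cut = next(iter(cuts))
print(cut.supervisions[0].normalized_text)
print(cut.tokens)  # e.g. ['sil', 'k', 'a3', '#0', ..., 'eos']
```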
328  egs/baker_zh/TTS/local/pypinyin-local.dict (new file)

@@ -0,0 +1,328 @@
|
姐姐 jie3 jie
|
||||||
|
宝宝 bao3 bao
|
||||||
|
哥哥 ge1 ge
|
||||||
|
妹妹 mei4 mei
|
||||||
|
弟弟 di4 di
|
||||||
|
妈妈 ma1 ma
|
||||||
|
开心哦 kai1 xin1 o
|
||||||
|
爸爸 ba4 ba
|
||||||
|
秘密哟 mi4 mi4 yo
|
||||||
|
哦 o
|
||||||
|
一年 yi4 nian2
|
||||||
|
一夜 yi2 ye4
|
||||||
|
一切 yi2 qie4
|
||||||
|
一座 yi2 zuo4
|
||||||
|
一下 yi2 xia4
|
||||||
|
上一山 shang4 yi2 shan1
|
||||||
|
下一山 xia4 yi2 shan1
|
||||||
|
休息 xiu1 xi2
|
||||||
|
东西 dong1 xi
|
||||||
|
上一届 shang4 yi2 jie4
|
||||||
|
便宜 pian2 yi4
|
||||||
|
加长 jia1 chang2
|
||||||
|
单田芳 shan4 tian2 fang1
|
||||||
|
帧 zhen1
|
||||||
|
长时间 chang2 shi2 jian1
|
||||||
|
长时 chang2 shi2
|
||||||
|
识别 shi2 bie2
|
||||||
|
生命中 sheng1 ming4 zhong1
|
||||||
|
踏实 ta1 shi
|
||||||
|
嗯 en4
|
||||||
|
溜达 liu1 da
|
||||||
|
少儿 shao4 er2
|
||||||
|
爷爷 ye2 ye
|
||||||
|
不是 bu2 shi4
|
||||||
|
一圈 yi1 quan1
|
||||||
|
厜读一声 zui1 du2 yi4 sheng1
|
||||||
|
一种 yi4 zhong3
|
||||||
|
一簇簇 yi2 cu4 cu4
|
||||||
|
一个 yi2 ge4
|
||||||
|
一样 yi2 yang4
|
||||||
|
一跩一跩 yi4 zhuai3 yi4 zhuai3
|
||||||
|
一会儿 yi2 hui4 er
|
||||||
|
一幢 yi2 zhuang4
|
||||||
|
挨了 ai2 le
|
||||||
|
熬菜 ao1 cai4
|
||||||
|
扒鸡 pa2 ji1
|
||||||
|
背枪 bei1 qiang1
|
||||||
|
绷瓷儿 beng4 ci2 er2
|
||||||
|
绷劲儿 beng3 jin4 er
|
||||||
|
绷着脸 beng3 zhe lian3
|
||||||
|
藏医 zang4 yi1
|
||||||
|
噌吰 cheng1 hong2
|
||||||
|
差点儿 cha4 dian3 er
|
||||||
|
差失 cha1 shi1
|
||||||
|
差误 cha1 wu4
|
||||||
|
孱头 can4 tou
|
||||||
|
乘间 cheng2 jian4
|
||||||
|
锄镰棘矜 chu2 lian2 ji2 qin2
|
||||||
|
川藏 chuan1 zang4
|
||||||
|
穿著 chuan1 zhuo2
|
||||||
|
答讪 da1 shan4
|
||||||
|
答言 da1 yan2
|
||||||
|
大伯子 da4 bai3 zi
|
||||||
|
大夫 dai4 fu
|
||||||
|
弹冠 tan2 guan1
|
||||||
|
当间 dang1 jian4
|
||||||
|
当然咯 dang1 ran2 lo
|
||||||
|
点种 dian3 zhong3
|
||||||
|
垛好 duo4 hao3
|
||||||
|
发疟子 fa1 yao4 zi
|
||||||
|
饭熟了 fan4 shou2 le
|
||||||
|
附著 fu4 zhuo2
|
||||||
|
复沓 fu4 ta4
|
||||||
|
供稿 gong1 gao3
|
||||||
|
供养 gong1 yang3
|
||||||
|
骨朵 gu1 duo
|
||||||
|
骨碌 gu1 lu
|
||||||
|
果脯 guo3 fu3
|
||||||
|
哈什玛 ha4 shi2 ma3
|
||||||
|
海蜇 hai3 zhe2
|
||||||
|
呵欠 he1 qian
|
||||||
|
河水汤汤 he2 shui3 shang1 shang1
|
||||||
|
鹄立 hu2 li4
|
||||||
|
鹄望 hu2 wang4
|
||||||
|
混人 hun2 ren2
|
||||||
|
混水 hun2 shui3
|
||||||
|
鸡血 ji1 xie3
|
||||||
|
缉鞋口 qi1 xie2 kou3
|
||||||
|
亟来闻讯 qi4 lai2 wen2 xun4
|
||||||
|
计量 ji4 liang2
|
||||||
|
济水 ji3 shui3
|
||||||
|
间杂 jian4 za2
|
||||||
|
脚跐两只船 jiao3 ci3 liang3 zhi1 chuan2
|
||||||
|
脚儿 jue2 er2
|
||||||
|
口角 kou3 jiao3
|
||||||
|
勒石 le4 shi2
|
||||||
|
累进 lei3 jin4
|
||||||
|
累累如丧家之犬 lei2 lei2 ru2 sang4 jia1 zhi1 quan3
|
||||||
|
累年 lei3 nian2
|
||||||
|
脸涨通红 lian3 zhang4 tong1 hong2
|
||||||
|
踉锵 liang4 qiang1
|
||||||
|
燎眉毛 liao3 mei2 mao2
|
||||||
|
燎头发 liao3 tou2 fa4
|
||||||
|
溜达 liu1 da
|
||||||
|
溜缝儿 liu4 feng4 er
|
||||||
|
馏口饭 liu4 kou3 fan4
|
||||||
|
遛马 liu4 ma3
|
||||||
|
遛鸟 liu4 niao3
|
||||||
|
遛弯儿 liu4 wan1 er
|
||||||
|
楼枪机 lou1 qiang1 ji1
|
||||||
|
搂钱 lou1 qian2
|
||||||
|
鹿脯 lu4 fu3
|
||||||
|
露头 lou4 tou2
|
||||||
|
落魄 luo4 po4
|
||||||
|
捋胡子 lv3 hu2 zi
|
||||||
|
绿地 lv4 di4
|
||||||
|
麦垛 mai4 duo4
|
||||||
|
没劲儿 mei2 jin4 er
|
||||||
|
闷棍 men4 gun4
|
||||||
|
闷葫芦 men4 hu2 lu
|
||||||
|
闷头干 men1 tou2 gan4
|
||||||
|
蒙古 meng3 gu3
|
||||||
|
靡日不思 mi3 ri4 bu4 si1
|
||||||
|
缪姓 miao4 xing4
|
||||||
|
抹墙 mo4 qiang2
|
||||||
|
抹下脸 ma1 xia4 lian3
|
||||||
|
泥子 ni4 zi
|
||||||
|
拗不过 niu4 bu guo4
|
||||||
|
排车 pai3 che1
|
||||||
|
盘诘 pan2 jie2
|
||||||
|
膀肿 pang1 zhong3
|
||||||
|
炮干 bao1 gan1
|
||||||
|
炮格 pao2 ge2
|
||||||
|
碰钉子 peng4 ding1 zi
|
||||||
|
缥色 piao3 se4
|
||||||
|
瀑河 bao4 he2
|
||||||
|
蹊径 xi1 jing4
|
||||||
|
前后相属 qian2 hou4 xiang1 zhu3
|
||||||
|
翘尾巴 qiao4 wei3 ba
|
||||||
|
趄坡儿 qie4 po1 er
|
||||||
|
秦桧 qin2 hui4
|
||||||
|
圈马 juan1 ma3
|
||||||
|
雀盲眼 qiao3 mang2 yan3
|
||||||
|
雀子 qiao1 zi
|
||||||
|
三年五载 san1 nian2 wu3 zai3
|
||||||
|
加载 jia1 zai3
|
||||||
|
山大王 shan1 dai4 wang
|
||||||
|
苫屋草 shan4 wu1 cao3
|
||||||
|
数数 shu3 shu4
|
||||||
|
说客 shui4 ke4
|
||||||
|
思量 si1 liang2
|
||||||
|
伺侯 ci4 hou
|
||||||
|
踏实 ta1 shi
|
||||||
|
提溜 di1 liu
|
||||||
|
调拨 diao4 bo1
|
||||||
|
帖子 tie3 zi
|
||||||
|
铜钿 tong2 tian2
|
||||||
|
头昏脑涨 tou2 hun1 nao3 zhang4
|
||||||
|
褪色 tui4 se4
|
||||||
|
褪着手 tun4 zhe shou3
|
||||||
|
圩子 wei2 zi
|
||||||
|
尾巴 wei3 ba
|
||||||
|
系好船只 xi4 hao3 chuan2 zhi1
|
||||||
|
系好马匹 xi4 hao3 ma3 pi3
|
||||||
|
杏脯 xing4 fu3
|
||||||
|
姓单 xing4 shan4
|
||||||
|
姓葛 xing4 ge3
|
||||||
|
姓哈 xing4 ha3
|
||||||
|
姓解 xing4 xie4
|
||||||
|
姓秘 xing4 bi4
|
||||||
|
姓宁 xing4 ning4
|
||||||
|
旋风 xuan4 feng1
|
||||||
|
旋根车轴 xuan4 gen1 che1 zhou2
|
||||||
|
荨麻 qian2 ma2
|
||||||
|
一幢楼房 yi1 zhuang4 lou2 fang2
|
||||||
|
遗之千金 wei4 zhi1 qian1 jin1
|
||||||
|
殷殷 yin3 yin3
|
||||||
|
应招 ying4 zhao1
|
||||||
|
用称约 yong4 cheng4 yao1
|
||||||
|
约斤肉 yao1 jin1 rou4
|
||||||
|
晕机 yun4 ji1
|
||||||
|
熨贴 yu4 tie1
|
||||||
|
咋办 za3 ban4
|
||||||
|
咋呼 zha1 hu
|
||||||
|
仔兽 zi3 shou4
|
||||||
|
扎彩 za1 cai3
|
||||||
|
扎实 zha1 shi
|
||||||
|
扎腰带 za1 yao1 dai4
|
||||||
|
轧朋友 ga2 peng2 you3
|
||||||
|
爪子 zhua3 zi
|
||||||
|
折腾 zhe1 teng
|
||||||
|
着实 zhuo2 shi2
|
||||||
|
着我旧时裳 zhuo2 wo3 jiu4 shi2 chang2
|
||||||
|
枝蔓 zhi1 man4
|
||||||
|
中鹄 zhong1 hu2
|
||||||
|
中选 zhong4 xuan3
|
||||||
|
猪圈 zhu1 juan4
|
||||||
|
拽住不放 zhuai4 zhu4 bu4 fang4
|
||||||
|
转悠 zhuan4 you
|
||||||
|
庄稼熟了 zhuang1 jia shou2 le
|
||||||
|
酌量 zhuo2 liang2
|
||||||
|
罪行累累 zui4 xing2 lei3 lei3
|
||||||
|
一手 yi4 shou3
|
||||||
|
一去不复返 yi2 qu4 bu2 fu4 fan3
|
||||||
|
一颗 yi4 ke1
|
||||||
|
一件 yi2 jian4
|
||||||
|
一斤 yi4 jin1
|
||||||
|
一点 yi4 dian3
|
||||||
|
一朵 yi4 duo3
|
||||||
|
一声 yi4 sheng1
|
||||||
|
一身 yi4 shen1
|
||||||
|
不要 bu2 yao4
|
||||||
|
一人 yi4 ren2
|
||||||
|
一个 yi2 ge4
|
||||||
|
一把 yi4 ba3
|
||||||
|
一门 yi4 men2
|
||||||
|
一門 yi4 men2
|
||||||
|
一艘 yi4 sou1
|
||||||
|
一片 yi2 pian4
|
||||||
|
一篇 yi2 pian1
|
||||||
|
一份 yi2 fen4
|
||||||
|
好嗲 hao3 dia3
|
||||||
|
随地 sui2 di4
|
||||||
|
扁担长 bian3 dan4 chang3
|
||||||
|
一堆 yi4 dui1
|
||||||
|
不义 bu2 yi4
|
||||||
|
放一放 fang4 yi2 fang4
|
||||||
|
一米 yi4 mi3
|
||||||
|
一顿 yi2 dun4
|
||||||
|
一层楼 yi4 ceng2 lou2
|
||||||
|
一条 yi4 tiao2
|
||||||
|
一件 yi2 jian4
|
||||||
|
一棵 yi4 ke1
|
||||||
|
一小股 yi4 xiao3 gu3
|
||||||
|
一拐一拐 yi4 guai3 yi4 guai3
|
||||||
|
一根 yi4 gen1
|
||||||
|
沆瀣一气 hang4 xie4 yi2 qi4
|
||||||
|
一丝 yi4 si1
|
||||||
|
一毫 yi4 hao2
|
||||||
|
一樣 yi2 yang4
|
||||||
|
处处 chu4 chu4
|
||||||
|
一餐 yi4 can
|
||||||
|
永不 yong3 bu2
|
||||||
|
一看 yi2 kan4
|
||||||
|
一架 yi2 jia4
|
||||||
|
送还 song4 huan2
|
||||||
|
一见 yi2 jian4
|
||||||
|
一座 yi2 zuo4
|
||||||
|
一块 yi2 kuai4
|
||||||
|
一天 yi4 tian1
|
||||||
|
一只 yi4 zhi1
|
||||||
|
一支 yi4 zhi1
|
||||||
|
一字 yi2 zi4
|
||||||
|
一句 yi2 ju4
|
||||||
|
一张 yi4 zhang1
|
||||||
|
一條 yi4 tiao2
|
||||||
|
一场 yi4 chang3
|
||||||
|
一粒 yi2 li4
|
||||||
|
小俩口 xiao3 liang3 kou3
|
||||||
|
一首 yi4 shou3
|
||||||
|
一对 yi2 dui4
|
||||||
|
一手 yi4 shou3
|
||||||
|
又一村 you4 yi4 cun1
|
||||||
|
一概而论 yi2 gai4 er2 lun4
|
||||||
|
一峰峰 yi4 feng1 feng1
|
||||||
|
不但 bu2 dan4
|
||||||
|
一笑 yi2 xiao4
|
||||||
|
挠痒痒 nao2 yang3 yang
|
||||||
|
不对 bu2 dui4
|
||||||
|
拧开 ning3 kai1
|
||||||
|
爱不释手 ai4 bu2 shi4 shou3
|
||||||
|
一念 yi2 nian4
|
||||||
|
夺得 duo2 de2
|
||||||
|
一袭 yi4 xi2
|
||||||
|
一定 yi2 ding4
|
||||||
|
不慎 bu2 shen4
|
||||||
|
剽窃 piao2 qie4
|
||||||
|
一时 yi4 shi2
|
||||||
|
撇开 pie3 kai1
|
||||||
|
一祭 yi2 ji4
|
||||||
|
发卡 fa4 qia3
|
||||||
|
少不了 shao3 bu4 liao3
|
||||||
|
千虑一失 qian1 lv4 yi4 shi1
|
||||||
|
呛得 qiang4 de2
|
||||||
|
切菜 qie1 cai4
|
||||||
|
茄盒 qie2 he2
|
||||||
|
不去 bu2 qu4
|
||||||
|
一大圈 yi2 da4 quan1
|
||||||
|
不再 bu2 zai4
|
||||||
|
一群 yi4 qun2
|
||||||
|
不必 bu2 bi4
|
||||||
|
一些 yi4 xie1
|
||||||
|
一路 yi2 lu4
|
||||||
|
一股 yi4 gu3
|
||||||
|
一到 yi2 dao4
|
||||||
|
一拨 yi4 bo1
|
||||||
|
一排 yi4 pai2
|
||||||
|
一空 yi4 kong1
|
||||||
|
吮吸着 shun3 xi1 zhe
|
||||||
|
不适合 bu2 shi4 he2
|
||||||
|
一串串 yi2 chuan4 chuan4
|
||||||
|
一提起 yi4 ti2 qi3
|
||||||
|
一尘不染 yi4 chen2 bu4 ran3
|
||||||
|
一生 yi4 sheng1
|
||||||
|
一派 yi2 pai4
|
||||||
|
不断 bu2 duan4
|
||||||
|
一次 yi2 ci4
|
||||||
|
不进步 bu2 jin4 bu4
|
||||||
|
娃娃 wa2 wa
|
||||||
|
万户侯 wan4 hu4 hou2
|
||||||
|
一方 yi4 fang1
|
||||||
|
一番话 yi4 fan1 hua4
|
||||||
|
一遍 yi2 bian4
|
||||||
|
不计较 bu2 ji4 jiao4
|
||||||
|
诇 xiong4
|
||||||
|
一边 yi4 bian1
|
||||||
|
一束 yi2 shu4
|
||||||
|
一听到 yi4 ting1 dao4
|
||||||
|
炸鸡 zha2 ji1
|
||||||
|
乍暧还寒 zha4 ai4 huan2 han2
|
||||||
|
我说诶 wo3 shuo1 ei1
|
||||||
|
棒诶 bang4 ei1
|
||||||
|
寒碜 han2 chen4
|
||||||
|
应采儿 ying4 cai3 er2
|
||||||
|
晕车 yun1 che1
|
||||||
|
必应 bi4 ying4
|
||||||
|
应援 ying4 yuan2
|
||||||
|
应力 ying4 li4
|
||||||
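Each line above is a word or phrase followed by its per-character pinyin, and the tokenizer in this commit feeds these entries to pypinyin via `load_phrases_dict` so that they override the default readings (tone sandhi, neutral tones, polyphonic characters). An illustrative sketch (not part of the commit), using one entry from the file and the same data format the tokenizer builds:

```python
from pypinyin import Style, pinyin
from pypinyin.core import load_phrases_dict

# One inner list per character, as in pypinyin-local.dict: "一年 yi4 nian2"
load_phrases_dict({"一年": [["yi4"], ["nian2"]]})
print(pinyin("一年", style=Style.TONE3))  # should print [['yi4'], ['nian2']]
```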
73  egs/baker_zh/TTS/local/symbols.py (new file)

@@ -0,0 +1,73 @@
# This file is copied from
# https://github.com/UEhQZXI/vits_chinese/blob/master/text/symbols.py
_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"]

_initials = [
    "^",
    "b",
    "c",
    "ch",
    "d",
    "f",
    "g",
    "h",
    "j",
    "k",
    "l",
    "m",
    "n",
    "p",
    "q",
    "r",
    "s",
    "sh",
    "t",
    "x",
    "z",
    "zh",
]

_tones = ["1", "2", "3", "4", "5"]

_finals = [
    "a",
    "ai",
    "an",
    "ang",
    "ao",
    "e",
    "ei",
    "en",
    "eng",
    "er",
    "i",
    "ia",
    "ian",
    "iang",
    "iao",
    "ie",
    "ii",
    "iii",
    "in",
    "ing",
    "iong",
    "iou",
    "o",
    "ong",
    "ou",
    "u",
    "ua",
    "uai",
    "uan",
    "uang",
    "uei",
    "uen",
    "ueng",
    "uo",
    "v",
    "van",
    "ve",
    "vn",
]

symbols = _pause + _initials + [i + j for i in _finals for j in _tones]
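Given the lists above (7 pause symbols, 22 initials, 38 finals, 5 tones), the generated symbol table should contain 7 + 22 + 38 × 5 = 219 entries. A quick check (not part of the commit), run from egs/baker_zh/TTS/local:

```python
from symbols import symbols

assert len(symbols) == 7 + 22 + 38 * 5  # 219
print(len(symbols), symbols[:8])
```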
137  egs/baker_zh/TTS/local/tokenizer.py (new file)

@@ -0,0 +1,137 @@
# This file is modified from
# https://github.com/UEhQZXI/vits_chinese/blob/master/vits_strings.py

import logging
from pathlib import Path
from typing import Dict, List

# Note pinyin_dict is from ./pinyin_dict.py
from pinyin_dict import pinyin_dict
from pypinyin import Style
from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
from pypinyin.converter import DefaultConverter
from pypinyin.core import Pinyin, load_phrases_dict


class _MyConverter(NeutralToneWith5Mixin, DefaultConverter):
    pass


class Tokenizer:
    def __init__(self, tokens: str = ""):
        self._load_pinyin_dict()
        self._pinyin_parser = Pinyin(_MyConverter())

        if tokens != "":
            self._load_tokens(tokens)

    def texts_to_token_ids(self, texts: List[str], **kwargs) -> List[List[int]]:
        """
        Args:
          texts:
            A list of sentences.
          kwargs:
            Not used. It is for compatibility with other TTS recipes in icefall.
        """
        tokens = []

        for text in texts:
            tokens.append(self.text_to_tokens(text))

        return self.tokens_to_token_ids(tokens)

    def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]:
        ans = []

        for token_list in tokens:
            token_ids = []
            for t in token_list:
                if t not in self.token2id:
                    logging.warning(f"Skip OOV {t}")
                    continue
                token_ids.append(self.token2id[t])
            ans.append(token_ids)

        return ans

    def text_to_tokens(self, text: str) -> List[str]:
        # Convert "," to ["sp", "sil"]
        # Convert "。" to ["sil"]
        # append ["eos"] at the end of a sentence
        phonemes = ["sil"]
        pinyins = self._pinyin_parser.pinyin(
            text,
            style=Style.TONE3,
            errors=lambda x: [[w] for w in x],
        )

        new_pinyin = []
        for p in pinyins:
            p = p[0]
            if p == ",":
                new_pinyin.extend(["sp", "sil"])
            elif p == "。":
                new_pinyin.append("sil")
            else:
                new_pinyin.append(p)
        sub_phonemes = self._get_phoneme4pinyin(new_pinyin)
        sub_phonemes.append("eos")
        phonemes.extend(sub_phonemes)
        return phonemes

    def _get_phoneme4pinyin(self, pinyins):
        result = []
        for pinyin in pinyins:
            if pinyin in ("sil", "sp"):
                result.append(pinyin)
            elif pinyin[:-1] in pinyin_dict:
                tone = pinyin[-1]
                a = pinyin[:-1]
                a1, a2 = pinyin_dict[a]
                # every word is appended with a #0
                result += [a1, a2 + tone, "#0"]

        return result

    def _load_pinyin_dict(self):
        this_dir = Path(__file__).parent.resolve()
        my_dict = {}
        with open(f"{this_dir}/pypinyin-local.dict", "r", encoding="utf-8") as f:
            content = f.readlines()
            for line in content:
                cuts = line.strip().split()
                hanzi = cuts[0]
                pinyin = cuts[1:]
                my_dict[hanzi] = [[p] for p in pinyin]

        load_phrases_dict(my_dict)

    def _load_tokens(self, filename):
        token2id: Dict[str, int] = {}

        with open(filename, "r", encoding="utf-8") as f:
            for line in f.readlines():
                info = line.rstrip().split()
                if len(info) == 1:
                    # case of space
                    token = " "
                    idx = int(info[0])
                else:
                    token, idx = info[0], int(info[1])

                assert token not in token2id, token

                token2id[token] = idx

        self.token2id = token2id
        self.vocab_size = len(self.token2id)
        self.pad_id = self.token2id["#0"]


def main():
    tokenizer = Tokenizer()
    # Smoke test: convert a short sentence to phoneme tokens.
    tokenizer.text_to_tokens("你好,好的。")


if __name__ == "__main__":
    main()
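A usage sketch (not part of the commit), assuming data/tokens.txt has already been generated and the snippet is run from a directory where tokenizer.py is importable:

```python
from tokenizer import Tokenizer

tokenizer = Tokenizer("data/tokens.txt")

tokens = tokenizer.text_to_tokens("你好,好的。")
print(tokens)
# e.g. ['sil', 'n', 'i3', '#0', 'h', 'ao3', '#0', ..., 'eos']

ids = tokenizer.texts_to_token_ids(["你好,好的。"])
print(ids[0])
```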
1  egs/baker_zh/TTS/local/validate_manifest.py (new symbolic link)

@@ -0,0 +1 @@
../../../ljspeech/TTS/local/validate_manifest.py
@@ -0,0 +1,124 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

stage=-1
stop_stage=100

dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: build monotonic_align lib"
  if [ ! -d vits/monotonic_align/build ]; then
    cd vits/monotonic_align
    python3 setup.py build_ext --inplace
    cd ../../
  else
    log "monotonic_align lib already built"
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Download data"

  # The directory $dl_dir/BZNSYP will contain 3 sub directories:
  # - PhoneLabeling
  # - ProsodyLabeling
  # - Wave

  # If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink
  #
  #   ln -sfv /path/to/BZNSYP $dl_dir/
  #   touch $dl_dir/BZNSYP/.completed
  #
  if [ ! -d $dl_dir/BZNSYP ]; then
    lhotse download baker-zh $dl_dir
  fi
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Prepare baker-zh manifest"
  # We assume that you have downloaded the baker corpus
  # to $dl_dir/BZNSYP
  mkdir -p data/manifests
  if [ ! -e data/manifests/.baker.done ]; then
    lhotse prepare baker-zh $dl_dir/BZNSYP data/manifests
    touch data/manifests/.baker.done
  fi
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Compute spectrogram for baker (may take 3 minutes)"
  mkdir -p data/spectrogram
  if [ ! -e data/spectrogram/.baker.done ]; then
    ./local/compute_spectrogram_baker.py
    touch data/spectrogram/.baker.done
  fi

  if [ ! -e data/spectrogram/.baker-validated.done ]; then
    log "Validating data/spectrogram for baker"
    python3 ./local/validate_manifest.py \
      data/spectrogram/baker_zh_cuts_all.jsonl.gz
    touch data/spectrogram/.baker-validated.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Prepare tokens for baker-zh (may take 20 seconds)"
  if [ ! -e data/spectrogram/.baker_zh_with_token.done ]; then

    ./local/prepare_tokens_baker_zh.py

    mv -v data/spectrogram/baker_zh_cuts_with_tokens_all.jsonl.gz \
      data/spectrogram/baker_zh_cuts_all.jsonl.gz

    touch data/spectrogram/.baker_zh_with_token.done
  fi
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Split the baker-zh cuts into train, valid and test sets (may take 25 seconds)"
  if [ ! -e data/spectrogram/.baker_zh_split.done ]; then
    lhotse subset --last 600 \
      data/spectrogram/baker_zh_cuts_all.jsonl.gz \
      data/spectrogram/baker_zh_cuts_validtest.jsonl.gz
    lhotse subset --first 100 \
      data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \
      data/spectrogram/baker_zh_cuts_valid.jsonl.gz
    lhotse subset --last 500 \
      data/spectrogram/baker_zh_cuts_validtest.jsonl.gz \
      data/spectrogram/baker_zh_cuts_test.jsonl.gz

    rm data/spectrogram/baker_zh_cuts_validtest.jsonl.gz

    n=$(( $(gunzip -c data/spectrogram/baker_zh_cuts_all.jsonl.gz | wc -l) - 600 ))
    lhotse subset --first $n \
      data/spectrogram/baker_zh_cuts_all.jsonl.gz \
      data/spectrogram/baker_zh_cuts_train.jsonl.gz
    touch data/spectrogram/.baker_zh_split.done
  fi
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Generate token file"
  if [ ! -e data/tokens.txt ]; then
    ./local/prepare_token_file.py --tokens data/tokens.txt
  fi
fi
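A sketch (not part of the commit) that checks for the artifacts the stages above are expected to leave behind once the script has completed:

```python
from pathlib import Path

expected = [
    "data/tokens.txt",
    "data/spectrogram/baker_zh_cuts_train.jsonl.gz",
    "data/spectrogram/baker_zh_cuts_valid.jsonl.gz",
    "data/spectrogram/baker_zh_cuts_test.jsonl.gz",
]
for p in expected:
    print(p, "exists" if Path(p).exists() else "MISSING")
```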
414  egs/baker_zh/TTS/vits/export-onnx.py (new executable file)

@@ -0,0 +1,414 @@
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script exports a VITS model from PyTorch to ONNX.
|
||||||
|
|
||||||
|
Export the model to ONNX:
|
||||||
|
./vits/export-onnx.py \
|
||||||
|
--epoch 1000 \
|
||||||
|
--exp-dir vits/exp \
|
||||||
|
--tokens data/tokens.txt
|
||||||
|
|
||||||
|
It will generate one file inside vits/exp:
|
||||||
|
- vits-epoch-1000.onnx
|
||||||
|
|
||||||
|
See ./test_onnx.py for how to use the exported ONNX models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from tokenizer import Tokenizer
|
||||||
|
from train import get_model, get_params
|
||||||
|
|
||||||
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="vits/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
default="data/tokens.txt",
|
||||||
|
help="""Path to vocabulary.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-type",
|
||||||
|
type=str,
|
||||||
|
default="high",
|
||||||
|
choices=["low", "medium", "high"],
|
||||||
|
help="""If not empty, valid values are: low, medium, high.
|
||||||
|
It controls the model size. low -> runs faster.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = str(value)
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModel(nn.Module):
|
||||||
|
"""A wrapper for VITS generator."""
|
||||||
|
|
||||||
|
def __init__(self, model: nn.Module):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
A VITS generator.
|
||||||
|
frame_shift:
|
||||||
|
The frame shift in samples.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tokens: torch.Tensor,
|
||||||
|
tokens_lens: torch.Tensor,
|
||||||
|
noise_scale: float = 0.667,
|
||||||
|
alpha: float = 1.0,
|
||||||
|
noise_scale_dur: float = 0.8,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Please see the help information of VITS.inference_batch
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens:
|
||||||
|
Input text token indexes (1, T_text)
|
||||||
|
tokens_lens:
|
||||||
|
Number of tokens of shape (1,)
|
||||||
|
noise_scale (float):
|
||||||
|
Noise scale parameter for flow.
|
||||||
|
noise_scale_dur (float):
|
||||||
|
Noise scale parameter for duration predictor.
|
||||||
|
alpha (float):
|
||||||
|
Alpha parameter to control the speed of generated speech.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- audio, generated wavform tensor, (B, T_wav)
|
||||||
|
"""
|
||||||
|
audio, _, _ = self.model.generator.inference(
|
||||||
|
text=tokens,
|
||||||
|
text_lengths=tokens_lens,
|
||||||
|
noise_scale=noise_scale,
|
||||||
|
noise_scale_dur=noise_scale_dur,
|
||||||
|
alpha=alpha,
|
||||||
|
)
|
||||||
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
def export_model_onnx(
|
||||||
|
model: nn.Module,
|
||||||
|
model_filename: str,
|
||||||
|
vocab_size: int,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given generator model to ONNX format.
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- tokens, a tensor of shape (1, T_text); dtype is torch.int64
|
||||||
|
|
||||||
|
and it has one output:
|
||||||
|
|
||||||
|
- audio, a tensor of shape (1, T'); dtype is torch.float32
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The VITS generator.
|
||||||
|
model_filename:
|
||||||
|
The filename to save the exported ONNX model.
|
||||||
|
vocab_size:
|
||||||
|
Number of tokens used in training.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
tokens = torch.randint(low=0, high=vocab_size, size=(1, 13), dtype=torch.int64)
|
||||||
|
tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
|
||||||
|
noise_scale = torch.tensor([1], dtype=torch.float32)
|
||||||
|
noise_scale_dur = torch.tensor([1], dtype=torch.float32)
|
||||||
|
alpha = torch.tensor([1], dtype=torch.float32)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
model,
|
||||||
|
(tokens, tokens_lens, noise_scale, alpha, noise_scale_dur),
|
||||||
|
model_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"tokens",
|
||||||
|
"tokens_lens",
|
||||||
|
"noise_scale",
|
||||||
|
"alpha",
|
||||||
|
"noise_scale_dur",
|
||||||
|
],
|
||||||
|
output_names=["audio"],
|
||||||
|
dynamic_axes={
|
||||||
|
"tokens": {0: "N", 1: "T"},
|
||||||
|
"tokens_lens": {0: "N"},
|
||||||
|
"audio": {0: "N", 1: "T"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if model.model.spks is None:
|
||||||
|
num_speakers = 1
|
||||||
|
else:
|
||||||
|
num_speakers = model.model.spks
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "vits",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"comment": "icefall", # must be icefall for models from icefall
|
||||||
|
"language": "Chinese",
|
||||||
|
"n_speakers": num_speakers,
|
||||||
|
"sample_rate": model.model.sampling_rate, # Must match the real sample rate
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
add_meta_data(filename=model_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
tokenizer = Tokenizer(params.tokens)
|
||||||
|
params.blank_id = tokenizer.pad_id
|
||||||
|
params.vocab_size = tokenizer.vocab_size
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_model(params)
|
||||||
|
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model = OnnxModel(model=model)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"generator parameters: {num_param}, or {num_param/1000/1000} M")
|
||||||
|
|
||||||
|
suffix = f"epoch-{params.epoch}"
|
||||||
|
|
||||||
|
opset_version = 13
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
model_filename = params.exp_dir / f"vits-{suffix}.onnx"
|
||||||
|
export_model_onnx(
|
||||||
|
model,
|
||||||
|
model_filename,
|
||||||
|
params.vocab_size,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported generator to {model_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
||||||
|
|
||||||
|
"""
|
||||||
|
Supported languages.
|
||||||
|
|
||||||
|
LJSpeech is using "en-us" from the second column.
|
||||||
|
|
||||||
|
Pty Language Age/Gender VoiceName File Other Languages
|
||||||
|
5 af --/M Afrikaans gmw/af
|
||||||
|
5 am --/M Amharic sem/am
|
||||||
|
5 an --/M Aragonese roa/an
|
||||||
|
5 ar --/M Arabic sem/ar
|
||||||
|
 5  as              --/M      Assamese             inc/as
 5  az              --/M      Azerbaijani          trk/az
 5  ba              --/M      Bashkir              trk/ba
 5  be              --/M      Belarusian           zle/be
 5  bg              --/M      Bulgarian            zls/bg
 5  bn              --/M      Bengali              inc/bn
 5  bpy             --/M      Bishnupriya_Manipuri inc/bpy
 5  bs              --/M      Bosnian              zls/bs
 5  ca              --/M      Catalan              roa/ca
 5  chr-US-Qaaa-x-west --/M   Cherokee_            iro/chr
 5  cmn             --/M      Chinese_(Mandarin,_latin_as_English) sit/cmn (zh-cmn 5)(zh 5)
 5  cmn-latn-pinyin --/M      Chinese_(Mandarin,_latin_as_Pinyin) sit/cmn-Latn-pinyin (zh-cmn 5)(zh 5)
 5  cs              --/M      Czech                zlw/cs
 5  cv              --/M      Chuvash              trk/cv
 5  cy              --/M      Welsh                cel/cy
 5  da              --/M      Danish               gmq/da
 5  de              --/M      German               gmw/de
 5  el              --/M      Greek                grk/el
 5  en-029          --/M      English_(Caribbean)  gmw/en-029 (en 10)
 2  en-gb           --/M      English_(Great_Britain) gmw/en (en 2)
 5  en-gb-scotland  --/M      English_(Scotland)   gmw/en-GB-scotland (en 4)
 5  en-gb-x-gbclan  --/M      English_(Lancaster)  gmw/en-GB-x-gbclan (en-gb 3)(en 5)
 5  en-gb-x-gbcwmd  --/M      English_(West_Midlands) gmw/en-GB-x-gbcwmd (en-gb 9)(en 9)
 5  en-gb-x-rp      --/M      English_(Received_Pronunciation) gmw/en-GB-x-rp (en-gb 4)(en 5)
 2  en-us           --/M      English_(America)    gmw/en-US (en 3)
 5  en-us-nyc       --/M      English_(America,_New_York_City) gmw/en-US-nyc
 5  eo              --/M      Esperanto            art/eo
 5  es              --/M      Spanish_(Spain)      roa/es
 5  es-419          --/M      Spanish_(Latin_America) roa/es-419 (es-mx 6)
 5  et              --/M      Estonian             urj/et
 5  eu              --/M      Basque               eu
 5  fa              --/M      Persian              ira/fa
 5  fa-latn         --/M      Persian_(Pinglish)   ira/fa-Latn
 5  fi              --/M      Finnish              urj/fi
 5  fr-be           --/M      French_(Belgium)     roa/fr-BE (fr 8)
 5  fr-ch           --/M      French_(Switzerland) roa/fr-CH (fr 8)
 5  fr-fr           --/M      French_(France)      roa/fr (fr 5)
 5  ga              --/M      Gaelic_(Irish)       cel/ga
 5  gd              --/M      Gaelic_(Scottish)    cel/gd
 5  gn              --/M      Guarani              sai/gn
 5  grc             --/M      Greek_(Ancient)      grk/grc
 5  gu              --/M      Gujarati             inc/gu
 5  hak             --/M      Hakka_Chinese        sit/hak
 5  haw             --/M      Hawaiian             map/haw
 5  he              --/M      Hebrew               sem/he
 5  hi              --/M      Hindi                inc/hi
 5  hr              --/M      Croatian             zls/hr (hbs 5)
 5  ht              --/M      Haitian_Creole       roa/ht
 5  hu              --/M      Hungarian            urj/hu
 5  hy              --/M      Armenian_(East_Armenia) ine/hy (hy-arevela 5)
 5  hyw             --/M      Armenian_(West_Armenia) ine/hyw (hy-arevmda 5)(hy 8)
 5  ia              --/M      Interlingua          art/ia
 5  id              --/M      Indonesian           poz/id
 5  io              --/M      Ido                  art/io
 5  is              --/M      Icelandic            gmq/is
 5  it              --/M      Italian              roa/it
 5  ja              --/M      Japanese             jpx/ja
 5  jbo             --/M      Lojban               art/jbo
 5  ka              --/M      Georgian             ccs/ka
 5  kk              --/M      Kazakh               trk/kk
 5  kl              --/M      Greenlandic          esx/kl
 5  kn              --/M      Kannada              dra/kn
 5  ko              --/M      Korean               ko
 5  kok             --/M      Konkani              inc/kok
 5  ku              --/M      Kurdish              ira/ku
 5  ky              --/M      Kyrgyz               trk/ky
 5  la              --/M      Latin                itc/la
 5  lb              --/M      Luxembourgish        gmw/lb
 5  lfn             --/M      Lingua_Franca_Nova   art/lfn
 5  lt              --/M      Lithuanian           bat/lt
 5  ltg             --/M      Latgalian            bat/ltg
 5  lv              --/M      Latvian              bat/lv
 5  mi              --/M      Māori                poz/mi
 5  mk              --/M      Macedonian           zls/mk
 5  ml              --/M      Malayalam            dra/ml
 5  mr              --/M      Marathi              inc/mr
 5  ms              --/M      Malay                poz/ms
 5  mt              --/M      Maltese              sem/mt
 5  mto             --/M      Totontepec_Mixe      miz/mto
 5  my              --/M      Myanmar_(Burmese)    sit/my
 5  nb              --/M      Norwegian_Bokmål     gmq/nb (no 5)
 5  nci             --/M      Nahuatl_(Classical)  azc/nci
 5  ne              --/M      Nepali               inc/ne
 5  nl              --/M      Dutch                gmw/nl
 5  nog             --/M      Nogai                trk/nog
 5  om              --/M      Oromo                cus/om
 5  or              --/M      Oriya                inc/or
 5  pa              --/M      Punjabi              inc/pa
 5  pap             --/M      Papiamento           roa/pap
 5  piqd            --/M      Klingon              art/piqd
 5  pl              --/M      Polish               zlw/pl
 5  pt              --/M      Portuguese_(Portugal) roa/pt (pt-pt 5)
 5  pt-br           --/M      Portuguese_(Brazil)  roa/pt-BR (pt 6)
 5  py              --/M      Pyash                art/py
 5  qdb             --/M      Lang_Belta           art/qdb
 5  qu              --/M      Quechua              qu
 5  quc             --/M      K'iche'              myn/quc
 5  qya             --/M      Quenya               art/qya
 5  ro              --/M      Romanian             roa/ro
 5  ru              --/M      Russian              zle/ru
 5  ru-cl           --/M      Russian_(Classic)    zle/ru-cl
 2  ru-lv           --/M      Russian_(Latvia)     zle/ru-LV
 5  sd              --/M      Sindhi               inc/sd
 5  shn             --/M      Shan_(Tai_Yai)       tai/shn
 5  si              --/M      Sinhala              inc/si
 5  sjn             --/M      Sindarin             art/sjn
 5  sk              --/M      Slovak               zlw/sk
 5  sl              --/M      Slovenian            zls/sl
 5  smj             --/M      Lule_Saami           urj/smj
 5  sq              --/M      Albanian             ine/sq
 5  sr              --/M      Serbian              zls/sr
 5  sv              --/M      Swedish              gmq/sv
 5  sw              --/M      Swahili              bnt/sw
 5  ta              --/M      Tamil                dra/ta
 5  te              --/M      Telugu               dra/te
 5  th              --/M      Thai                 tai/th
 5  tk              --/M      Turkmen              trk/tk
 5  tn              --/M      Setswana             bnt/tn
 5  tr              --/M      Turkish              trk/tr
 5  tt              --/M      Tatar                trk/tt
 5  ug              --/M      Uyghur               trk/ug
 5  uk              --/M      Ukrainian            zle/uk
 5  ur              --/M      Urdu                 inc/ur
 5  uz              --/M      Uzbek                trk/uz
 5  vi              --/M      Vietnamese_(Northern) aav/vi
 5  vi-vn-x-central --/M      Vietnamese_(Central) aav/vi-VN-x-central
 5  vi-vn-x-south   --/M      Vietnamese_(Southern) aav/vi-VN-x-south
 5  yue             --/M      Chinese_(Cantonese)  sit/yue (zh-yue 5)(zh 8)
 5  yue             --/M      Chinese_(Cantonese,_latin_as_Jyutping) sit/yue-Latn-jyutping (zh-yue 5)(zh 8)
"""
39
egs/baker_zh/TTS/vits/generate_lexicon.py
Executable file
@ -0,0 +1,39 @@
#!/usr/bin/env python3

from pypinyin import phrases_dict, pinyin_dict
from tokenizer import Tokenizer


def main():
    filename = "lexicon.txt"
    tokens = "./data/tokens.txt"
    tokenizer = Tokenizer(tokens)

    word_dict = pinyin_dict.pinyin_dict
    phrases = phrases_dict.phrases_dict

    i = 0
    with open(filename, "w", encoding="utf-8") as f:
        for key in word_dict:
            if not (0x4E00 <= key <= 0x9FFF):
                continue

            w = chr(key)

            # 1 to remove the initial sil
            # :-1 to remove the final eos
            tokens = tokenizer.text_to_tokens(w)[1:-1]

            tokens = " ".join(tokens)
            f.write(f"{w} {tokens}\n")

        for key in phrases:
            # 1 to remove the initial sil
            # :-1 to remove the final eos
            tokens = tokenizer.text_to_tokens(key)[1:-1]
            tokens = " ".join(tokens)
            f.write(f"{key} {tokens}\n")


if __name__ == "__main__":
    main()
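Note: the script above writes one entry per line, a Chinese character or phrase followed by its phonemized tokens; the exact token strings depend on the data/tokens.txt generated during data preparation. The two lines below are only an illustrative sketch of the file layout, not actual output of this script:

    中国 zh ong1 g uo2
    你好 n i3 h ao3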
1
egs/baker_zh/TTS/vits/pinyin_dict.py
Symbolic link
@ -0,0 +1 @@
../local/pinyin_dict.py
1
egs/baker_zh/TTS/vits/pypinyin-local.dict
Symbolic link
@ -0,0 +1 @@
../local/pypinyin-local.dict
142
egs/baker_zh/TTS/vits/test_onnx.py
Executable file
@ -0,0 +1,142 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script is used to test the exported onnx model by vits/export-onnx.py

Use the onnx model to generate a wav:
./vits/test_onnx.py \
  --model-filename vits/exp/vits-epoch-1000.onnx \
  --tokens data/tokens.txt
"""


import argparse
import logging

import onnxruntime as ort
import torch
import torchaudio
from tokenizer import Tokenizer


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-filename",
        type=str,
        required=True,
        help="Path to the onnx model.",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    parser.add_argument(
        "--text",
        type=str,
        default="Ask not what your country can do for you; ask what you can do for your country.",
        help="Text to generate speech for",
    )

    parser.add_argument(
        "--output-filename",
        type=str,
        default="test_onnx.wav",
        help="Filename to save the generated wave file.",
    )

    return parser


class OnnxModel:
    def __init__(self, model_filename: str):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            model_filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )
        logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")

        metadata = self.model.get_modelmeta().custom_metadata_map
        self.sample_rate = int(metadata["sample_rate"])

    def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
        """
        Args:
          tokens:
            A 1-D tensor of shape (1, T)
        Returns:
          A tensor of shape (1, T')
        """
        noise_scale = torch.tensor([0.667], dtype=torch.float32)
        noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
        alpha = torch.tensor([1.0], dtype=torch.float32)

        out = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: tokens.numpy(),
                self.model.get_inputs()[1].name: tokens_lens.numpy(),
                self.model.get_inputs()[2].name: noise_scale.numpy(),
                self.model.get_inputs()[3].name: alpha.numpy(),
                self.model.get_inputs()[4].name: noise_scale_dur.numpy(),
            },
        )[0]
        return torch.from_numpy(out)


def main():
    args = get_parser().parse_args()
    logging.info(vars(args))

    tokenizer = Tokenizer(args.tokens)

    logging.info("About to create onnx model")
    model = OnnxModel(args.model_filename)

    text = args.text
    tokens = tokenizer.texts_to_token_ids([text])
    tokens = torch.tensor(tokens)  # (1, T)
    tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)  # (1, T)
    audio = model(tokens, tokens_lens)  # (1, T')

    output_filename = args.output_filename
    torchaudio.save(output_filename, audio, sample_rate=model.sample_rate)
    logging.info(f"Saved to {output_filename}")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
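For the baker_zh recipe the input text is expected to be Chinese, so a typical invocation would look roughly like the following; the checkpoint name and the sentence are placeholders, not values taken from this commit:

    ./vits/test_onnx.py \
      --model-filename vits/exp/vits-epoch-1000.onnx \
      --tokens data/tokens.txt \
      --text "这是一个语音合成测试" \
      --output-filename test_onnx.wav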
1
egs/baker_zh/TTS/vits/tokenizer.py
Symbolic link
@ -0,0 +1 @@
../local/tokenizer.py
927
egs/baker_zh/TTS/vits/train.py
Executable file
@ -1 +0,0 @@
-../../../ljspeech/TTS/vits/train.py
@ -0,0 +1,927 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union

import k2
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import BakerZhSpeechTtsDataModule
from utils import MetricsTracker, plot_feature, save_checkpoint
from vits import VITS

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.utils import AttributeDict, setup_logger, str2bool

LRSchedulerType = torch.optim.lr_scheduler._LRScheduler


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs for DDP training.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12354,
        help="Master port to use for DDP training.",
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Should various information be logged in tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=1000,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=1,
        help="""Resume training from this epoch. It should be positive.
        If larger than 1, it will load checkpoint from
        exp-dir/epoch-{start_epoch-1}.pt
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="vits/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--tokens",
        type=str,
        default="data/tokens.txt",
        help="""Path to vocabulary.""",
    )

    parser.add_argument(
        "--lr", type=float, default=2.0e-4, help="The base learning rate."
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed for random generators intended for reproducibility",
    )

    parser.add_argument(
        "--print-diagnostics",
        type=str2bool,
        default=False,
        help="Accumulate stats on activations, print them and exit.",
    )

    parser.add_argument(
        "--inf-check",
        type=str2bool,
        default=False,
        help="Add hooks to check for infinite module outputs and gradients.",
    )

    parser.add_argument(
        "--save-every-n",
        type=int,
        default=20,
        help="""Save checkpoint after processing this number of epochs"
        periodically. We save checkpoint to exp-dir/ whenever
        params.cur_epoch % save_every_n == 0. The checkpoint filename
        has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
        Since it will take around 1000 epochs, we suggest using a large
        save_every_n to save disk space.
        """,
    )

    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=False,
        help="Whether to use half precision training.",
    )

    parser.add_argument(
        "--model-type",
        type=str,
        default="high",
        choices=["low", "medium", "high"],
        help="""If not empty, valid values are: low, medium, high.
        It controls the model size. low -> runs faster.
        """,
    )

    return parser


def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed from the commandline
    are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

        - best_train_loss: Best training loss so far. It is used to select
                           the model that has the lowest training loss. It is
                           updated during the training.

        - best_valid_loss: Best validation loss so far. It is used to select
                           the model that has the lowest validation loss. It is
                           updated during the training.

        - best_train_epoch: It is the epoch that has the best training loss.

        - best_valid_epoch: It is the epoch that has the best validation loss.

        - batch_idx_train: Used to writing statistics to tensorboard. It
                           contains number of batches trained so far across
                           epochs.

        - log_interval: Print training loss if batch_idx % log_interval` is 0

        - valid_interval: Run validation if batch_idx % valid_interval is 0

        - feature_dim: The model input dim. It has to match the one used
                       in computing features.
    """
    params = AttributeDict(
        {
            # training params
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": -1,  # 0
            "log_interval": 50,
            "valid_interval": 200,
            "env_info": get_env_info(),
            "sampling_rate": 48000,
            "frame_shift": 256,
            "frame_length": 1024,
            "feature_dim": 513,  # 1024 // 2 + 1, 1024 is fft_length
            "n_mels": 80,
            "lambda_adv": 1.0,  # loss scaling coefficient for adversarial loss
            "lambda_mel": 45.0,  # loss scaling coefficient for Mel loss
            "lambda_feat_match": 2.0,  # loss scaling coefficient for feat match loss
            "lambda_dur": 1.0,  # loss scaling coefficient for duration loss
            "lambda_kl": 1.0,  # loss scaling coefficient for KL divergence loss
        }
    )

    return params


def load_checkpoint_if_available(
    params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
    """Load checkpoint from file.

    If params.start_epoch is larger than 1, it will load the checkpoint from
    `params.start_epoch - 1`.

    Apart from loading state dict for `model` and `optimizer` it also updates
    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
    and `best_valid_loss` in `params`.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
    Returns:
      Return a dict containing previously saved training info.
    """
    if params.start_epoch > 1:
        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    else:
        return None

    assert filename.is_file(), f"{filename} does not exist!"

    saved_params = load_checkpoint(filename, model=model)

    keys = [
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    ]
    for k in keys:
        params[k] = saved_params[k]

    return saved_params


def get_model(params: AttributeDict) -> nn.Module:
    mel_loss_params = {
        "n_mels": params.n_mels,
        "frame_length": params.frame_length,
        "frame_shift": params.frame_shift,
    }
    model = VITS(
        vocab_size=params.vocab_size,
        feature_dim=params.feature_dim,
        sampling_rate=params.sampling_rate,
        model_type=params.model_type,
        mel_loss_params=mel_loss_params,
        lambda_adv=params.lambda_adv,
        lambda_mel=params.lambda_mel,
        lambda_feat_match=params.lambda_feat_match,
        lambda_dur=params.lambda_dur,
        lambda_kl=params.lambda_kl,
    )
    return model


def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
    """Parse batch data"""
    audio = batch["audio"].to(device)
    features = batch["features"].to(device)
    audio_lens = batch["audio_lens"].to(device)
    features_lens = batch["features_lens"].to(device)
    tokens = batch["tokens"]

    tokens = tokenizer.tokens_to_token_ids(tokens)
    tokens = k2.RaggedTensor(tokens)
    row_splits = tokens.shape.row_splits(1)
    tokens_lens = row_splits[1:] - row_splits[:-1]
    tokens = tokens.to(device)
    tokens_lens = tokens_lens.to(device)
    # a tensor of shape (B, T)
    tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)

    return audio, audio_lens, features, features_lens, tokens, tokens_lens


def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: Tokenizer,
    optimizer_g: Optimizer,
    optimizer_d: Optimizer,
    scheduler_g: LRSchedulerType,
    scheduler_d: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    The training loss from the mean of all frames is saved in
    `params.train_loss`. It runs the validation process every
    `params.valid_interval` batches.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      tokenizer:
        Used to convert text to phonemes.
      optimizer_g:
        The optimizer for generator.
      optimizer_d:
        The optimizer for discriminator.
      scheduler_g:
        The learning rate scheduler for generator, we call step() every epoch.
      scheduler_d:
        The learning rate scheduler for discriminator, we call step() every epoch.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mix precision training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of nodes in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
    """
    model.train()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # used to track the stats over iterations in one epoch
    tot_loss = MetricsTracker()

    saved_bad_model = False

    def save_bad_model(suffix: str = ""):
        save_checkpoint(
            filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
            model=model,
            params=params,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,
            scheduler_d=scheduler_d,
            sampler=train_dl.sampler,
            scaler=scaler,
            rank=0,
        )

    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1

        batch_size = len(batch["tokens"])
        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
            batch, tokenizer, device
        )

        loss_info = MetricsTracker()
        loss_info["samples"] = batch_size

        try:
            with autocast(enabled=params.use_fp16):
                # forward discriminator
                loss_d, stats_d = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=False,
                )
            for k, v in stats_d.items():
                loss_info[k] = v * batch_size
            # update discriminator
            optimizer_d.zero_grad()
            scaler.scale(loss_d).backward()
            scaler.step(optimizer_d)

            with autocast(enabled=params.use_fp16):
                # forward generator
                loss_g, stats_g = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=True,
                    return_sample=params.batch_idx_train % params.log_interval == 0,
                )
            for k, v in stats_g.items():
                if "returned_sample" not in k:
                    loss_info[k] = v * batch_size
            # update generator
            optimizer_g.zero_grad()
            scaler.scale(loss_g).backward()
            scaler.step(optimizer_g)
            scaler.update()

            # summary stats
            tot_loss = tot_loss + loss_info
        except:  # noqa
            save_bad_model()
            raise

        if params.print_diagnostics and batch_idx == 5:
            return

        if params.batch_idx_train % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 8.0 or (
                cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
            ):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
                    save_bad_model(suffix="-first-warning")
                    saved_bad_model = True
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
                if cur_grad_scale < 1.0e-05:
                    save_bad_model()
                    raise RuntimeError(
                        f"grad_scale is too small, exiting: {cur_grad_scale}"
                    )

        if params.batch_idx_train % params.log_interval == 0:
            cur_lr_g = max(scheduler_g.get_last_lr())
            cur_lr_d = max(scheduler_d.get_last_lr())
            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0

            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
                f"loss[{loss_info}], tot_loss[{tot_loss}], "
                f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate_g", cur_lr_g, params.batch_idx_train
                )
                tb_writer.add_scalar(
                    "train/learning_rate_d", cur_lr_d, params.batch_idx_train
                )
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
                    )
                if "returned_sample" in stats_g:
                    speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
                    tb_writer.add_audio(
                        "train/speech_hat_",
                        speech_hat_,
                        params.batch_idx_train,
                        params.sampling_rate,
                    )
                    tb_writer.add_audio(
                        "train/speech_",
                        speech_,
                        params.batch_idx_train,
                        params.sampling_rate,
                    )
                    tb_writer.add_image(
                        "train/mel_hat_",
                        plot_feature(mel_hat_),
                        params.batch_idx_train,
                        dataformats="HWC",
                    )
                    tb_writer.add_image(
                        "train/mel_",
                        plot_feature(mel_),
                        params.batch_idx_train,
                        dataformats="HWC",
                    )

        if (
            params.batch_idx_train % params.valid_interval == 0
            and not params.print_diagnostics
        ):
            logging.info("Computing validation loss")
            valid_info, (speech_hat, speech) = compute_validation_loss(
                params=params,
                model=model,
                tokenizer=tokenizer,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )
                tb_writer.add_audio(
                    "train/valdi_speech_hat",
                    speech_hat,
                    params.batch_idx_train,
                    params.sampling_rate,
                )
                tb_writer.add_audio(
                    "train/valdi_speech",
                    speech,
                    params.batch_idx_train,
                    params.sampling_rate,
                )

    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


def compute_validation_loss(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: Tokenizer,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
    rank: int = 0,
) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
    """Run the validation process."""
    model.eval()
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    # used to summary the stats over iterations
    tot_loss = MetricsTracker()
    returned_sample = None

    with torch.no_grad():
        for batch_idx, batch in enumerate(valid_dl):
            batch_size = len(batch["tokens"])
            (
                audio,
                audio_lens,
                features,
                features_lens,
                tokens,
                tokens_lens,
            ) = prepare_input(batch, tokenizer, device)

            loss_info = MetricsTracker()
            loss_info["samples"] = batch_size

            # forward discriminator
            loss_d, stats_d = model(
                text=tokens,
                text_lengths=tokens_lens,
                feats=features,
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
                forward_generator=False,
            )
            assert loss_d.requires_grad is False
            for k, v in stats_d.items():
                loss_info[k] = v * batch_size

            # forward generator
            loss_g, stats_g = model(
                text=tokens,
                text_lengths=tokens_lens,
                feats=features,
                feats_lengths=features_lens,
                speech=audio,
                speech_lengths=audio_lens,
                forward_generator=True,
            )
            assert loss_g.requires_grad is False
            for k, v in stats_g.items():
                loss_info[k] = v * batch_size

            # summary stats
            tot_loss = tot_loss + loss_info

            # infer for first batch:
            if batch_idx == 0 and rank == 0:
                inner_model = model.module if isinstance(model, DDP) else model
                audio_pred, _, duration = inner_model.inference(
                    text=tokens[0, : tokens_lens[0].item()]
                )
                audio_pred = audio_pred.data.cpu().numpy()
                audio_len_pred = (
                    (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
                )
                assert audio_len_pred == len(audio_pred), (
                    audio_len_pred,
                    len(audio_pred),
                )
                audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy()
                returned_sample = (audio_pred, audio_gt)

    if world_size > 1:
        tot_loss.reduce(device)

    loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss, returned_sample


def scan_pessimistic_batches_for_oom(
    model: Union[nn.Module, DDP],
    train_dl: torch.utils.data.DataLoader,
    tokenizer: Tokenizer,
    optimizer_g: torch.optim.Optimizer,
    optimizer_d: torch.optim.Optimizer,
    params: AttributeDict,
):
    from lhotse.dataset import find_pessimistic_batches

    logging.info(
        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
    )
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        audio, audio_lens, features, features_lens, tokens, tokens_lens = prepare_input(
            batch, tokenizer, device
        )
        try:
            # for discriminator
            with autocast(enabled=params.use_fp16):
                loss_d, stats_d = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=False,
                )
            optimizer_d.zero_grad()
            loss_d.backward()
            # for generator
            with autocast(enabled=params.use_fp16):
                loss_g, stats_g = model(
                    text=tokens,
                    text_lengths=tokens_lens,
                    feats=features,
                    feats_lengths=features_lens,
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=True,
                )
            optimizer_g.zero_grad()
            loss_g.backward()
        except Exception as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current "
                    "max_duration setting. We recommend decreasing "
                    "max_duration and trying again.\n"
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
            raise
        logging.info(
            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
        )


def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(params.seed)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

    tokenizer = Tokenizer(params.tokens)
    params.blank_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size

    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    generator = model.generator
    discriminator = model.discriminator

    num_param_g = sum([p.numel() for p in generator.parameters()])
    logging.info(f"Number of parameters in generator: {num_param_g}")
    num_param_d = sum([p.numel() for p in discriminator.parameters()])
    logging.info(f"Number of parameters in discriminator: {num_param_d}")
    logging.info(f"Total number of parameters: {num_param_g + num_param_d}")

    assert params.start_epoch > 0, params.start_epoch
    checkpoints = load_checkpoint_if_available(params=params, model=model)

    model.to(device)
    if world_size > 1:
        logging.info("Using DDP")
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer_g = torch.optim.AdamW(
        generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
    )
    optimizer_d = torch.optim.AdamW(
        discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
    )

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)

    if checkpoints is not None:
        # load state_dict for optimizers
        if "optimizer_g" in checkpoints:
            logging.info("Loading optimizer_g state dict")
            optimizer_g.load_state_dict(checkpoints["optimizer_g"])
        if "optimizer_d" in checkpoints:
            logging.info("Loading optimizer_d state dict")
            optimizer_d.load_state_dict(checkpoints["optimizer_d"])

        # load state_dict for schedulers
        if "scheduler_g" in checkpoints:
            logging.info("Loading scheduler_g state dict")
            scheduler_g.load_state_dict(checkpoints["scheduler_g"])
        if "scheduler_d" in checkpoints:
            logging.info("Loading scheduler_d state dict")
            scheduler_d.load_state_dict(checkpoints["scheduler_d"])

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            512
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

    if params.inf_check:
        register_inf_check_hooks(model)

    baker_zh = BakerZhSpeechTtsDataModule(args)

    train_cuts = baker_zh.train_cuts()

    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 20 seconds
        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold
        if c.duration < 1.0 or c.duration > 20.0:
            # logging.warning(
            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
            # )
            return False
        return True

    train_cuts = train_cuts.filter(remove_short_and_long_utt)
    train_dl = baker_zh.train_dataloaders(train_cuts)

    valid_cuts = baker_zh.valid_cuts()
    valid_dl = baker_zh.valid_dataloaders(valid_cuts)

    if not params.print_diagnostics:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            tokenizer=tokenizer,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            params=params,
        )

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
        logging.info("Loading grad scaler state dict")
        scaler.load_state_dict(checkpoints["grad_scaler"])

    for epoch in range(params.start_epoch, params.num_epochs + 1):
        logging.info(f"Start epoch {epoch}")

        fix_random_seed(params.seed + epoch - 1)
        train_dl.sampler.set_epoch(epoch - 1)

        params.cur_epoch = epoch

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        train_one_epoch(
            params=params,
            model=model,
            tokenizer=tokenizer,
            optimizer_g=optimizer_g,
            optimizer_d=optimizer_d,
            scheduler_g=scheduler_g,
            scheduler_d=scheduler_d,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
        )

        if params.print_diagnostics:
            diagnostic.print_diagnostics()
            break

        if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
            filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
            save_checkpoint(
                filename=filename,
                params=params,
                model=model,
                optimizer_g=optimizer_g,
                optimizer_d=optimizer_d,
                scheduler_g=scheduler_g,
                scheduler_d=scheduler_d,
                sampler=train_dl.sampler,
                scaler=scaler,
                rank=rank,
            )
            if rank == 0:
                if params.best_train_epoch == params.cur_epoch:
                    best_train_filename = params.exp_dir / "best-train-loss.pt"
                    copyfile(src=filename, dst=best_train_filename)

                if params.best_valid_epoch == params.cur_epoch:
                    best_valid_filename = params.exp_dir / "best-valid-loss.pt"
                    copyfile(src=filename, dst=best_valid_filename)

        # step per epoch
        scheduler_g.step()
        scheduler_d.step()

    logging.info("Done!")

    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


def main():
    parser = get_parser()
    BakerZhSpeechTtsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
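A single-GPU training run with the script above would be launched roughly as follows. All flags shown are defined by get_parser() or by the data module in tts_datamodule.py; the values are placeholders for illustration only, not settings taken from this commit:

    ./vits/train.py \
      --world-size 1 \
      --num-epochs 1000 \
      --exp-dir vits/exp \
      --tokens data/tokens.txt \
      --max-duration 500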
@ -52,7 +52,7 @@ class _SeedWorkers:
        fix_random_seed(self.seed + worker_id)


-class LJSpeechTtsDataModule:
+class BakerZhSpeechTtsDataModule:
    """
    DataModule for tts experiments.
    It assumes there is always one train and valid dataloader,
@ -66,11 +66,12 @@ class LJSpeechTtsDataModule:
    - cut concatenation,
    - on-the-fly feature extraction

-    This class should be derived for specific corpora used in ASR tasks.
+    This class should be derived for specific corpora used in TTS tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
+        self.sampling_rate = 48000

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
@ -175,7 +176,7 @@ class LJSpeechTtsDataModule:
            )

        if self.args.on_the_fly_feats:
-            sampling_rate = 22050
+            sampling_rate = self.sampling_rate
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in second),
@ -232,7 +233,7 @@ class LJSpeechTtsDataModule:
    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
-            sampling_rate = 22050
+            sampling_rate = self.sampling_rate
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in second),
@ -272,7 +273,7 @@ class LJSpeechTtsDataModule:
    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.info("About to create test dataset")
        if self.args.on_the_fly_feats:
-            sampling_rate = 22050
+            sampling_rate = self.sampling_rate
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                frame_length=1024 / sampling_rate,  # (in second),
@ -311,19 +312,19 @@ class LJSpeechTtsDataModule:
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
+            self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz"
        )

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        logging.info("About to get validation cuts")
        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
+            self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz"
        )

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.info("About to get test cuts")
        return load_manifest_lazy(
-            self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
+            self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz"
        )
@ -1,7 +1,10 @@
 # https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py
 """Setup cython code."""

-from Cython.Build import cythonize
+try:
+    from Cython.Build import cythonize
+except ModuleNotFoundError as ex:
+    raise RuntimeError(f"{ex}\nPlease run:\n  pip install cython")
 from setuptools import Extension, setup
 from setuptools.command.build_ext import build_ext as _build_ext
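With this change a missing Cython installation fails with an actionable message instead of a bare import error. The extension itself is still built in the usual way; as a sketch (the working directory is assumed to be egs/baker_zh/TTS, which is not stated in this hunk):

    cd vits/monotonic_align
    python3 setup.py build_ext --inplace
    cd ../..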
@ -44,11 +44,11 @@ class Tokenizer(object):
            if len(info) == 1:
                # case of space
                token = " "
-                id = int(info[0])
+                idx = int(info[0])
            else:
-                token, id = info[0], int(info[1])
+                token, idx = info[0], int(info[1])
            assert token not in self.token2id, token
-            self.token2id[token] = id
+            self.token2id[token] = idx

        # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
        self.pad_id = self.token2id["_"]  # padding
@ -66,7 +66,7 @@ class LJSpeechTtsDataModule:
    - cut concatenation,
    - on-the-fly feature extraction

-    This class should be derived for specific corpora used in ASR tasks.
+    This class should be derived for specific corpora used in TTS tasks.
    """

    def __init__(self, args: argparse.Namespace):