Add MatchaTTS for the Chinese dataset Baker (#1849)

Fangjun Kuang 2024-12-31 17:17:05 +08:00 committed by GitHub
parent df46a3eaf9
commit bfffda5afb
34 changed files with 3128 additions and 12 deletions

167
.github/scripts/baker_zh/TTS/run-matcha.sh vendored Executable file
@@ -0,0 +1,167 @@
#!/usr/bin/env bash
set -ex
apt-get update
apt-get install -y sox
python3 -m pip install numba conformer==0.3.2 diffusers librosa
python3 -m pip install jieba
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/baker_zh/TTS
sed -i.bak s/600/8/g ./prepare.sh
sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
sed -i.bak s/500/5/g ./prepare.sh
git diff
function prepare_data() {
# We have created a subset of the data for testing
#
mkdir -p download
pushd download
wget -q https://huggingface.co/csukuangfj/tmp-files/resolve/main/BZNSYP-samples.tar.bz2
tar xvf BZNSYP-samples.tar.bz2
mv BZNSYP-samples BZNSYP
rm BZNSYP-samples.tar.bz2
popd
./prepare.sh
tree .
}
function train() {
pushd ./matcha
sed -i.bak s/1500/3/g ./train.py
git diff .
popd
./matcha/train.py \
--exp-dir matcha/exp \
--num-epochs 1 \
--save-every-n 1 \
--num-buckets 2 \
--tokens data/tokens.txt \
--max-duration 20
ls -lh matcha/exp
}
function infer() {
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
./matcha/infer.py \
--num-buckets 2 \
--epoch 1 \
--exp-dir ./matcha/exp \
--tokens data/tokens.txt \
--cmvn ./data/fbank/cmvn.json \
--vocoder ./generator_v2 \
--input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
--output-wav ./generated.wav
ls -lh *.wav
soxi ./generated.wav
rm -v ./generated.wav
rm -v generator_v2
}
function export_onnx() {
pushd matcha/exp
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27/resolve/main/epoch-2000.pt
popd
pushd data/fbank
rm -v *.json
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27/resolve/main/cmvn.json
popd
./matcha/export_onnx.py \
--exp-dir ./matcha/exp \
--epoch 2000 \
--tokens ./data/tokens.txt \
--cmvn ./data/fbank/cmvn.json
ls -lh *.onnx
if false; then
# The CI machine does not have enough memory to run it
#
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3
python3 ./matcha/export_onnx_hifigan.py
else
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v2.onnx
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v3.onnx
fi
ls -lh *.onnx
python3 ./matcha/generate_lexicon.py
for v in v1 v2 v3; do
python3 ./matcha/onnx_pretrained.py \
--acoustic-model ./model-steps-6.onnx \
--vocoder ./hifigan_$v.onnx \
--tokens ./data/tokens.txt \
--lexicon ./lexicon.txt \
--input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
--output-wav /icefall/generated-matcha-tts-steps-6-$v.wav
done
ls -lh /icefall/*.wav
soxi /icefall/generated-matcha-tts-steps-6-*.wav
cp ./model-steps-*.onnx /icefall
d=matcha-icefall-zh-baker
mkdir $d
cp -v data/tokens.txt $d
cp -v lexicon.txt $d
cp model-steps-3.onnx $d
pushd $d
curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2
tar xvf dict.tar.bz2
rm dict.tar.bz2
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst
curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst
cat >README.md <<EOF
# Introduction
This model is trained using the dataset from
https://en.data-baker.com/datasets/freeDatasets/
The dataset contains 10000 Chinese sentences recorded by a native Chinese
female speaker, about 12 hours of audio in total.
**Note**: The dataset is for non-commercial use only.
You can find the training code at
https://github.com/k2-fsa/icefall/tree/master/egs/baker_zh/TTS
EOF
ls -lh
popd
tar cvjf $d.tar.bz2 $d
mv $d.tar.bz2 /icefall
mv $d /icefall
}
prepare_data
train
infer
export_onnx
rm -rfv generator_v* matcha/exp
git checkout .

.github/scripts/docker/generate_build_matrix.py

@@ -2,9 +2,19 @@
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import json
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--min-torch-version",
help="Minimu torch version",
)
return parser.parse_args()
def version_gt(a, b):
a_major, a_minor = list(map(int, a.split(".")))[:2]
b_major, b_minor = list(map(int, b.split(".")))[:2]
@@ -42,7 +52,7 @@ def get_torchaudio_version(torch_version):
return torch_version
-def get_matrix():
+def get_matrix(min_torch_version):
k2_version = "1.24.4.dev20241029"
kaldifeat_version = "1.25.5.dev20241029"
version = "20241218"
@@ -64,6 +74,9 @@ def get_matrix():
matrix = []
for p in python_version:
for t in torch_version:
if min_torch_version and version_gt(min_torch_version, t):
continue
# torchaudio <= 1.13.x supports only python <= 3.10
if version_gt(p, "3.10") and not version_gt(t, "2.0"):
@@ -101,7 +114,8 @@ def get_matrix():
def main():
-matrix = get_matrix()
+args = get_args()
+matrix = get_matrix(min_torch_version=args.min_torch_version)
print(json.dumps({"include": matrix}))

.github/scripts/ljspeech/TTS/run-matcha.sh

@@ -90,7 +90,7 @@ function export_onnx() {
ls -lh *.onnx
if false; then
-# THe CI machine does not have enough memory to run it
+# The CI machine does not have enough memory to run it
#
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2

152
.github/workflows/baker_zh.yml vendored Normal file
@@ -0,0 +1,152 @@
name: baker_zh
on:
push:
branches:
- master
- baker-matcha-2
pull_request:
branches:
- master
workflow_dispatch:
concurrency:
group: baker-zh-${{ github.ref }}
cancel-in-progress: true
jobs:
generate_build_matrix:
if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa'
# see https://github.com/pytorch/pytorch/pull/50633
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Generating build matrix
id: set-matrix
run: |
# outputting for debugging purposes
python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3"
MATRIX=$(python ./.github/scripts/docker/generate_build_matrix.py --min-torch-version "2.3")
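# The script prints a JSON object of the form {"include": [...]}
# (see main() in generate_build_matrix.py); the baker_zh job below
# consumes it through fromJson().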
echo "::set-output name=matrix::${MATRIX}"
baker_zh:
needs: generate_build_matrix
name: py${{ matrix.python-version }} torch${{ matrix.torch-version }} v${{ matrix.version }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
${{ fromJson(needs.generate_build_matrix.outputs.matrix) }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Free space
shell: bash
run: |
ls -lh
df -h
rm -rf /opt/hostedtoolcache
df -h
echo "pwd: $PWD"
echo "github.workspace ${{ github.workspace }}"
- name: Run tests
uses: addnab/docker-run-action@v3
with:
image: ghcr.io/${{ github.repository_owner }}/icefall:cpu-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-v${{ matrix.version }}
options: |
--volume ${{ github.workspace }}/:/icefall
shell: bash
run: |
export PYTHONPATH=/icefall:$PYTHONPATH
cd /icefall
pip install onnx==1.17.0
pip list
git config --global --add safe.directory /icefall
.github/scripts/baker_zh/TTS/run-matcha.sh
- name: display files
shell: bash
run: |
ls -lh
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
with:
name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
path: ./*.wav
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
with:
name: step-2
path: ./model-steps-2.onnx
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
with:
name: step-3
path: ./model-steps-3.onnx
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
with:
name: step-4
path: ./model-steps-4.onnx
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
with:
name: step-5
path: ./model-steps-5.onnx
- uses: actions/upload-artifact@v4
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
with:
name: step-6
path: ./model-steps-6.onnx
- name: Upload models to huggingface
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
shell: bash
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
d=matcha-icefall-zh-baker
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/$d hf
cp -av $d/* hf/
pushd hf
git add .
git config --global user.name "csukuangfj"
git config --global user.email "csukuangfj@gmail.com"
git config --global lfs.allowincompletepush true
git commit -m "upload model" && git push https://csukuangfj:${HF_TOKEN}@huggingface.co/csukuangfj/$d main || true
popd
- name: Release exported onnx models
if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
overwrite: true
file: matcha-icefall-*.tar.bz2
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
tag: tts-models

6
egs/baker_zh/TTS/.gitignore vendored Normal file
@@ -0,0 +1,6 @@
path.sh
*.onnx
*.wav
generator_v1
generator_v2
generator_v3

146
egs/baker_zh/TTS/README.md Normal file
@@ -0,0 +1,146 @@
# Introduction
This recipe is for the dataset from
https://en.data-baker.com/datasets/freeDatasets/
The dataset contains 10000 Chinese sentences recorded by a native Chinese
female speaker, about 12 hours of audio in total.
**Note**: The dataset is for non-commercial use only.
# matcha
[./matcha](./matcha) contains the code for training [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS)
Checkpoints and training logs can be found [here](https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27).
The pull-request for this recipe can be found at <https://github.com/k2-fsa/icefall/pull/1849>
The training command is given below:
```bash
python3 ./matcha/train.py \
--exp-dir ./matcha/exp-1/ \
--num-workers 4 \
--world-size 1 \
--num-epochs 2000 \
--max-duration 1200 \
--bucketing-sampler 1 \
--start-epoch 1
```
To run inference, use:
```bash
# Download Hifigan vocoder. We use Hifigan v2 below. You can select from v1, v2, or v3
wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
python3 ./matcha/infer.py \
--epoch 2000 \
--exp-dir ./matcha/exp-1 \
--vocoder ./generator_v2 \
--tokens ./data/tokens.txt \
--cmvn ./data/fbank/cmvn.json \
--input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
--output-wav ./generated.wav
```
```bash
soxi ./generated.wav
```
prints:
```
Input File : './generated.wav'
Channels : 1
Sample Rate : 22050
Precision : 16-bit
Duration : 00:00:17.31 = 381696 samples ~ 1298.29 CDDA sectors
File Size : 763k
Bit Rate : 353k
Sample Encoding: 16-bit Signed Integer PCM
```
https://github.com/user-attachments/assets/88d4e88f-ebc4-4f32-b216-16d46b966024
To export the checkpoint to onnx:
```bash
python3 ./matcha/export_onnx.py \
--exp-dir ./matcha/exp-1 \
--epoch 2000 \
--tokens ./data/tokens.txt \
--cmvn ./data/fbank/cmvn.json
```
The above command generates the following files:
```
-rw-r--r-- 1 kuangfangjun root 72M Dec 27 18:53 model-steps-2.onnx
-rw-r--r-- 1 kuangfangjun root 73M Dec 27 18:54 model-steps-3.onnx
-rw-r--r-- 1 kuangfangjun root 73M Dec 27 18:54 model-steps-4.onnx
-rw-r--r-- 1 kuangfangjun root 74M Dec 27 18:55 model-steps-5.onnx
-rw-r--r-- 1 kuangfangjun root 74M Dec 27 18:57 model-steps-6.onnx
```
where the 2 in `model-steps-2.onnx` means it uses 2 steps for the ODE solver.
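If you only want to sanity-check an exported acoustic model, the sketch below (an illustration, not part of the recipe) runs it directly with onnxruntime. The input and output names are the ones used in `./matcha/export_onnx.py`; the token ids are arbitrary placeholders.

```python
# A minimal sketch: feed dummy token ids to an exported acoustic model
# and check the shape of the returned mel spectrogram.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model-steps-2.onnx", providers=["CPUExecutionProvider"])

x = np.array([[1, 10, 1, 20, 1]], dtype=np.int64)  # (batch_size, num_tokens)
(mel,) = sess.run(
    ["mel"],
    {
        "x": x,
        "x_length": np.array([x.shape[1]], dtype=np.int64),
        "noise_scale": np.array([1.0], dtype=np.float32),
        "length_scale": np.array([1.0], dtype=np.float32),
    },
)
print(mel.shape)  # (1, 80, num_frames)
```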
**HINT**: If you get the following error while running `export_onnx.py`:
```
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator
'aten::scaled_dot_product_attention' to ONNX opset version 14 is not supported.
```
please use `torch>=2.2.0`.
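If you want to check this before running the export, a minimal sketch (it only compares the major.minor version):

```python
# Abort early if the installed torch cannot export
# aten::scaled_dot_product_attention to ONNX.
import torch

major, minor = (int(v) for v in torch.__version__.split(".")[:2])
if (major, minor) < (2, 2):
    raise RuntimeError(f"Need torch>=2.2.0, found {torch.__version__}")
```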
To export the Hifigan vocoder to onnx, please use:
```bash
wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3
python3 ./matcha/export_onnx_hifigan.py
```
The above command generates 3 files:
- hifigan_v1.onnx
- hifigan_v2.onnx
- hifigan_v3.onnx
**HINT**: You can download pre-exported hifigan ONNX models from
<https://github.com/k2-fsa/sherpa-onnx/releases/tag/vocoder-models>
To use the generated onnx files to generate speech from text, please run:
```bash
# First, generate ./lexicon.txt
python3 ./matcha/generate_lexicon.py
python3 ./matcha/onnx_pretrained.py \
--acoustic-model ./model-steps-4.onnx \
--vocoder ./hifigan_v2.onnx \
--tokens ./data/tokens.txt \
--lexicon ./lexicon.txt \
--input-text "在一个阳光明媚的夏天,小马、小羊和小狗它们一块儿在广阔的草地上,嬉戏玩耍,这时小猴来了,还带着它心爱的足球活蹦乱跳地跑前、跑后教小马、小羊、小狗踢足球。" \
--output-wav ./1.wav
```
```bash
soxi ./1.wav
Input File : './1.wav'
Channels : 1
Sample Rate : 22050
Precision : 16-bit
Duration : 00:00:16.37 = 360960 samples ~ 1227.76 CDDA sectors
File Size : 722k
Bit Rate : 353k
Sample Encoding: 16-bit Signed Integer PCM
```
https://github.com/user-attachments/assets/578d04bb-fee8-47e5-9984-a868dcce610e

egs/baker_zh/TTS/local/audio.py

@@ -0,0 +1 @@
../matcha/audio.py

egs/baker_zh/TTS/local/compute_fbank_baker_zh.py

@@ -0,0 +1,110 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the baker-zh dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet
from icefall.utils import get_executor
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--num-jobs",
type=int,
default=4,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
return parser
def compute_fbank_baker_zh(num_jobs: int):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
if num_jobs < 1:
num_jobs = os.cpu_count()
logging.info(f"num_jobs: {num_jobs}")
logging.info(f"src_dir: {src_dir}")
logging.info(f"output_dir: {output_dir}")
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=22050,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
prefix = "baker_zh"
suffix = "jsonl.gz"
extractor = MatchaFbank(config)
with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{prefix}_cuts.{suffix}"
logging.info(f"Processing {cuts_filename}")
cut_set = load_manifest(src_dir / cuts_filename).resample(22050)
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats",
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_parser().parse_args()
compute_fbank_baker_zh(args.num_jobs)

egs/baker_zh/TTS/local/compute_fbank_statistics.py

@@ -0,0 +1 @@
../../../ljspeech/TTS/local/compute_fbank_statistics.py

egs/baker_zh/TTS/local/convert_text_to_tokens.py

@@ -0,0 +1,121 @@
#!/usr/bin/env python3
import argparse
import re
from typing import List
import jieba
from lhotse import load_manifest
from pypinyin import Style, lazy_pinyin, load_phrases_dict
load_phrases_dict(
{
"行长": [["hang2"], ["zhang3"]],
"银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]],
}
)
whitespace_re = re.compile(r"\s+")
punctuations_re = [
(re.compile(x[0], re.IGNORECASE), x[1])
for x in [
(",", ","),
("。", "."),
("!", "!"),
("?", "?"),
("“", '"'),
("”", '"'),
("‘", "'"),
("’", "'"),
(":", ":"),
("、", ","),
("", ""),
("", ""),
]
]
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--in-file",
type=str,
required=True,
help="Input cutset.",
)
parser.add_argument(
"--out-file",
type=str,
required=True,
help="Output cutset.",
)
return parser
def normalize_white_spaces(text):
return whitespace_re.sub(" ", text)
def normalize_punctuations(text):
for regex, replacement in punctuations_re:
text = re.sub(regex, replacement, text)
return text
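# For example, with the table above:
# normalize_punctuations("你好!她说:来了。") == '你好!她说:来了.'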
def split_text(text: str) -> List[str]:
"""
Example input: '你好呀,You are 一个好人。 去银行存钱?How about you?'
Example output: ['你好', '呀', ',', 'you are', '一个', '好人', '.', '去', '银行', '存钱', '?', 'how about you', '?']
"""
text = text.lower()
text = normalize_white_spaces(text)
text = normalize_punctuations(text)
ans = []
for seg in jieba.cut(text):
if seg in ",.!?:\"'":
ans.append(seg)
elif seg == " " and len(ans) > 0:
if ord("a") <= ord(ans[-1][-1]) <= ord("z"):
ans[-1] += seg
elif ord("a") <= ord(seg[0]) <= ord("z"):
if len(ans) == 0:
ans.append(seg)
continue
if ans[-1][-1] == " ":
ans[-1] += seg
continue
ans.append(seg)
else:
ans.append(seg)
ans = [s.strip() for s in ans]
return ans
def main():
args = get_parser().parse_args()
cuts = load_manifest(args.in_file)
for c in cuts:
assert len(c.supervisions) == 1, (len(c.supervisions), c.supervisions)
text = c.supervisions[0].normalized_text
text_list = split_text(text)
tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True)
c.tokens = tokens
cuts.to_file(args.out_file)
print(f"saved to {args.out_file}")
if __name__ == "__main__":
main()

egs/baker_zh/TTS/local/fbank.py

@@ -0,0 +1 @@
../matcha/fbank.py

egs/baker_zh/TTS/local/generate_tokens.py

@@ -0,0 +1,85 @@
#!/usr/bin/env python3
"""
This file generates the file tokens.txt.
Usage:
python3 ./local/generate_tokens.py --tokens data/tokens.txt
"""
import argparse
from typing import List
import jieba
from pypinyin import Style, lazy_pinyin, pinyin_dict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="Path to to save tokens.txt.",
)
return parser
def generate_token_list() -> List[str]:
token_set = set()
word_dict = pinyin_dict.pinyin_dict
i = 0
for key in word_dict:
if not (0x4E00 <= key <= 0x9FFF):
continue
w = chr(key)
t = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]
token_set.add(t)
no_digit = set()
for t in token_set:
if t[-1] not in "1234":
no_digit.add(t)
else:
no_digit.add(t[:-1])
no_digit.add("dei")
no_digit.add("tou")
no_digit.add("dia")
for t in no_digit:
token_set.add(t)
for i in range(1, 5):
token_set.add(f"{t}{i}")
ans = list(token_set)
ans.sort()
punctuations = list(",.!?:\"'")
ans = punctuations + ans
# Use ID 0 for the blank token " "
# Use ID 1 for the padding token "_"
ans.insert(0, " ")
ans.insert(1, "_")
return ans
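# With the two insertions above, the final list begins
# [" ", "_", ",", ".", "!", "?", ":", '"', "'", ...], followed by the
# sorted pinyin tokens, so the blank " " gets ID 0 and the padding
# token "_" gets ID 1 in the generated tokens.txt.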
def main():
args = get_parser().parse_args()
token_list = generate_token_list()
with open(args.tokens, "w", encoding="utf-8") as f:
for indx, token in enumerate(token_list):
f.write(f"{token} {indx}\n")
if __name__ == "__main__":
main()

egs/baker_zh/TTS/local/validate_manifest.py

@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
We will add more checks later if needed.
Usage example:
python3 ./local/validate_manifest.py \
./data/spectrogram/baker_zh_cuts_all.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset.speech_synthesis import validate_for_tts
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)
return parser.parse_args()
def main():
args = get_args()
manifest = args.manifest
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest_lazy(manifest)
assert isinstance(cut_set, CutSet), type(cut_set)
validate_for_tts(cut_set)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

egs/baker_zh/TTS/matcha/audio.py

@@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/audio.py

egs/baker_zh/TTS/matcha/export_onnx.py

@@ -0,0 +1,207 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
"""
This script exports a Matcha-TTS model to ONNX.
Note that the model outputs fbank. You need to use a vocoder to convert
it to audio. See also ./export_onnx_hifigan.py
python3 ./matcha/export_onnx.py \
--exp-dir ./matcha/exp-1 \
--epoch 2000 \
--tokens ./data/tokens.txt \
--cmvn ./data/fbank/cmvn.json
"""
import argparse
import json
import logging
from pathlib import Path
from typing import Any, Dict
import onnx
import torch
from tokenizer import Tokenizer
from train import get_model, get_params
from icefall.checkpoint import load_checkpoint
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=2000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=Path,
default="matcha/exp-new-3",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=Path,
default="data/tokens.txt",
)
parser.add_argument(
"--cmvn",
type=str,
default="data/fbank/cmvn.json",
help="""Path to vocabulary.""",
)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
class ModelWrapper(torch.nn.Module):
def __init__(self, model, num_steps: int = 5):
super().__init__()
self.model = model
self.num_steps = num_steps
def forward(
self,
x: torch.Tensor,
x_lengths: torch.Tensor,
noise_scale: torch.Tensor,
length_scale: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x: (batch_size, num_tokens), torch.int64
x_lengths: (batch_size,), torch.int64
noise_scale: (1,), torch.float32
length_scale: (1,), torch.float32
Returns:
mel: (batch_size, feat_dim, num_frames)
"""
mel = self.model.synthesise(
x=x,
x_lengths=x_lengths,
n_timesteps=self.num_steps,
temperature=noise_scale,
length_scale=length_scale,
)["mel"]
# mel: (batch_size, feat_dim, num_frames)
return mel
@torch.inference_mode()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
tokenizer = Tokenizer(params.tokens)
params.pad_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_args.n_vocab = params.vocab_size
with open(params.cmvn) as f:
stats = json.load(f)
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
params.data_args.data_statistics.mel_std = stats["fbank_std"]
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
params.model_args.data_statistics.mel_std = stats["fbank_std"]
logging.info(params)
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
for num_steps in [2, 3, 4, 5, 6]:
logging.info(f"num_steps: {num_steps}")
wrapper = ModelWrapper(model, num_steps=num_steps)
wrapper.eval()
# Use a large value so the rotary position embedding in the text
# encoder has a large initial length
x = torch.ones(1, 1000, dtype=torch.int64)
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
noise_scale = torch.tensor([1.0])
length_scale = torch.tensor([1.0])
opset_version = 14
filename = f"model-steps-{num_steps}.onnx"
torch.onnx.export(
wrapper,
(x, x_lengths, noise_scale, length_scale),
filename,
opset_version=opset_version,
input_names=["x", "x_length", "noise_scale", "length_scale"],
output_names=["mel"],
dynamic_axes={
"x": {0: "N", 1: "L"},
"x_length": {0: "N"},
"mel": {0: "N", 2: "L"},
},
)
meta_data = {
"model_type": "matcha-tts",
"language": "Chinese",
"has_espeak": 0,
"n_speakers": 1,
"jieba": 1,
"sample_rate": 22050,
"version": 1,
"pad_id": params.pad_id,
"model_author": "icefall",
"maintainer": "k2-fsa",
"dataset": "baker-zh",
"use_eos_bos": 0,
"dataset_url": "https://www.data-baker.com/open_source.html",
"dataset_comment": "The dataset is for non-commercial use only.",
"num_ode_steps": num_steps,
}
add_meta_data(filename=filename, meta_data=meta_data)
print(meta_data)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

egs/baker_zh/TTS/matcha/export_onnx_hifigan.py

@@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/export_onnx_hifigan.py

egs/baker_zh/TTS/matcha/fbank.py

@@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/fbank.py

egs/baker_zh/TTS/matcha/generate_lexicon.py

@@ -0,0 +1,42 @@
#!/usr/bin/env python3
import jieba
from pypinyin import Style, lazy_pinyin, load_phrases_dict, phrases_dict, pinyin_dict
from tokenizer import Tokenizer
load_phrases_dict(
{
"行长": [["hang2"], ["zhang3"]],
"银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]],
}
)
def main():
filename = "lexicon.txt"
tokens = "./data/tokens.txt"
tokenizer = Tokenizer(tokens)
word_dict = pinyin_dict.pinyin_dict
phrases = phrases_dict.phrases_dict
i = 0
with open(filename, "w", encoding="utf-8") as f:
for key in word_dict:
if not (0x4E00 <= key <= 0x9FFF):
continue
w = chr(key)
tokens = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]
f.write(f"{w} {tokens}\n")
for key in phrases:
tokens = lazy_pinyin(key, style=Style.TONE3, tone_sandhi=True)
tokens = " ".join(tokens)
f.write(f"{key} {tokens}\n")
if __name__ == "__main__":
main()

egs/baker_zh/TTS/matcha/hifigan

@@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/hifigan

342
egs/baker_zh/TTS/matcha/infer.py Executable file
@@ -0,0 +1,342 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
"""
python3 ./matcha/infer.py \
--epoch 2000 \
--exp-dir ./matcha/exp-1 \
--vocoder ./generator_v2 \
--tokens ./data/tokens.txt \
--cmvn ./data/fbank/cmvn.json \
--input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
--output-wav ./generated.wav
"""
import argparse
import datetime as dt
import json
import logging
from pathlib import Path
import soundfile as sf
import torch
import torch.nn as nn
from hifigan.config import v1, v2, v3
from hifigan.denoiser import Denoiser
from hifigan.models import Generator as HiFiGAN
from local.convert_text_to_tokens import split_text
from pypinyin import Style, lazy_pinyin
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import BakerZhTtsDataModule
from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=4000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--exp-dir",
type=Path,
default="matcha/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--vocoder",
type=Path,
default="./generator_v1",
help="Path to the vocoder",
)
parser.add_argument(
"--tokens",
type=Path,
default="data/tokens.txt",
)
parser.add_argument(
"--cmvn",
type=str,
default="data/fbank/cmvn.json",
help="""Path to vocabulary.""",
)
# The following arguments are used for inference on single text
parser.add_argument(
"--input-text",
type=str,
required=False,
help="The text to generate speech for",
)
parser.add_argument(
"--output-wav",
type=str,
required=False,
help="The filename of the wave to save the generated speech",
)
parser.add_argument(
"--sampling-rate",
type=int,
default=22050,
help="The sampling rate of the generated speech (default: 22050 for baker_zh)",
)
return parser
def load_vocoder(checkpoint_path: Path) -> nn.Module:
checkpoint_path = str(checkpoint_path)
if checkpoint_path.endswith("v1"):
h = AttributeDict(v1)
elif checkpoint_path.endswith("v2"):
h = AttributeDict(v2)
elif checkpoint_path.endswith("v3"):
h = AttributeDict(v3)
else:
raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
hifigan = HiFiGAN(h).to("cpu")
hifigan.load_state_dict(
torch.load(checkpoint_path, map_location="cpu")["generator"]
)
_ = hifigan.eval()
hifigan.remove_weight_norm()
return hifigan
def to_waveform(
mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module
) -> torch.Tensor:
audio = vocoder(mel).clamp(-1, 1)
audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
return audio.squeeze()
def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
text = split_text(text)
tokens = lazy_pinyin(text, style=Style.TONE3, tone_sandhi=True)
x = tokenizer.texts_to_token_ids([tokens])
x = torch.tensor(x, dtype=torch.long, device=device)
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
return {"x_orig": text, "x": x, "x_lengths": x_lengths}
def synthesize(
model: nn.Module,
tokenizer: Tokenizer,
n_timesteps: int,
text: str,
length_scale: float,
temperature: float,
device: str = "cpu",
spks=None,
) -> dict:
text_processed = process_text(text=text, tokenizer=tokenizer, device=device)
start_t = dt.datetime.now()
output = model.synthesise(
text_processed["x"],
text_processed["x_lengths"],
n_timesteps=n_timesteps,
temperature=temperature,
spks=spks,
length_scale=length_scale,
)
# merge everything to one dict
output.update({"start_t": start_t, **text_processed})
return output
def infer_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
vocoder: nn.Module,
denoiser: nn.Module,
tokenizer: Tokenizer,
) -> None:
"""Decode dataset.
The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
vocoder:
The vocoder used to convert mel spectrograms to waveforms.
denoiser:
Used to reduce artifacts in the vocoder output.
tokenizer:
Used to convert text to token ids.
"""
device = next(model.parameters()).device
num_cuts = 0
log_interval = 5
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
for batch_idx, batch in enumerate(dl):
batch_size = len(batch["tokens"])
texts = [c.supervisions[0].normalized_text for c in batch["cut"]]
audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist()
cut_ids = [cut.id for cut in batch["cut"]]
for i in range(batch_size):
output = synthesize(
model=model,
tokenizer=tokenizer,
n_timesteps=params.n_timesteps,
text=texts[i],
length_scale=params.length_scale,
temperature=params.temperature,
device=device,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
sf.write(
file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav",
data=output["waveform"],
samplerate=params.data_args.sampling_rate,
subtype="PCM_16",
)
sf.write(
file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav",
data=audio[i].numpy(),
samplerate=params.data_args.sampling_rate,
subtype="PCM_16",
)
num_cuts += batch_size
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
@torch.inference_mode()
def main():
parser = get_parser()
BakerZhTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}"
params.res_dir = params.exp_dir / "infer" / params.suffix
params.save_wav_dir = params.res_dir / "wav"
params.save_wav_dir.mkdir(parents=True, exist_ok=True)
setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
logging.info("Infer started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
params.vocab_size = tokenizer.vocab_size
params.model_args.n_vocab = params.vocab_size
with open(params.cmvn) as f:
stats = json.load(f)
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
params.data_args.data_statistics.mel_std = stats["fbank_std"]
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
params.model_args.data_statistics.mel_std = stats["fbank_std"]
# Number of ODE Solver steps
params.n_timesteps = 2
# Changes to the speaking rate
params.length_scale = 1.0
# Sampling temperature
params.temperature = 0.667
logging.info(params)
logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
# we need cut ids to organize tts results.
args.return_cuts = True
baker_zh = BakerZhTtsDataModule(args)
test_cuts = baker_zh.test_cuts()
test_dl = baker_zh.test_dataloaders(test_cuts)
if not Path(params.vocoder).is_file():
raise ValueError(f"{params.vocoder} does not exist")
vocoder = load_vocoder(params.vocoder)
vocoder.to(device)
denoiser = Denoiser(vocoder, mode="zeros")
denoiser.to(device)
if params.input_text is not None and params.output_wav is not None:
logging.info("Synthesizing a single text")
output = synthesize(
model=model,
tokenizer=tokenizer,
n_timesteps=params.n_timesteps,
text=params.input_text,
length_scale=params.length_scale,
temperature=params.temperature,
device=device,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
sf.write(
file=params.output_wav,
data=output["waveform"],
samplerate=params.sampling_rate,
subtype="PCM_16",
)
else:
logging.info("Decoding the test set")
infer_dataset(
dl=test_dl,
params=params,
model=model,
vocoder=vocoder,
denoiser=denoiser,
tokenizer=tokenizer,
)
if __name__ == "__main__":
main()

egs/baker_zh/TTS/matcha/model.py

@@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/model.py

egs/baker_zh/TTS/matcha/models

@@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/models

egs/baker_zh/TTS/matcha/monotonic_align

@@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/monotonic_align

egs/baker_zh/TTS/matcha/onnx_pretrained.py

@@ -0,0 +1,316 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
"""
python3 ./matcha/onnx_pretrained.py \
--acoustic-model ./model-steps-4.onnx \
--vocoder ./hifigan_v2.onnx \
--tokens ./data/tokens.txt \
--lexicon ./lexicon.txt \
--input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \
--output-wav ./b.wav
"""
import argparse
import datetime as dt
import logging
import re
from typing import Dict, List
import jieba
import onnxruntime as ort
import soundfile as sf
import torch
from infer import load_vocoder
from utils import intersperse
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--acoustic-model",
type=str,
required=True,
help="Path to the acoustic model",
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="Path to the tokens.txt",
)
parser.add_argument(
"--lexicon",
type=str,
required=True,
help="Path to the lexicon.txt",
)
parser.add_argument(
"--vocoder",
type=str,
required=True,
help="Path to the vocoder",
)
parser.add_argument(
"--input-text",
type=str,
required=True,
help="The text to generate speech for",
)
parser.add_argument(
"--output-wav",
type=str,
required=True,
help="The filename of the wave to save the generated speech",
)
return parser
class OnnxHifiGANModel:
def __init__(
self,
filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.model = ort.InferenceSession(
filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
for i in self.model.get_inputs():
print(i)
print("-----")
for i in self.model.get_outputs():
print(i)
def __call__(self, x: torch.Tensor):
assert x.ndim == 3, x.shape
assert x.shape[0] == 1, x.shape
audio = self.model.run(
[self.model.get_outputs()[0].name],
{
self.model.get_inputs()[0].name: x.numpy(),
},
)[0]
# audio: (batch_size, num_samples)
return torch.from_numpy(audio)
class OnnxModel:
def __init__(
self,
filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 2
self.session_opts = session_opts
self.model = ort.InferenceSession(
filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
metadata = self.model.get_modelmeta().custom_metadata_map
self.sample_rate = int(metadata["sample_rate"])
for i in self.model.get_inputs():
print(i)
print("-----")
for i in self.model.get_outputs():
print(i)
def __call__(self, x: torch.Tensor):
assert x.ndim == 2, x.shape
assert x.shape[0] == 1, x.shape
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
print("x_lengths", x_lengths)
print("x", x.shape)
noise_scale = torch.tensor([1.0], dtype=torch.float32)
length_scale = torch.tensor([1.0], dtype=torch.float32)
mel = self.model.run(
[self.model.get_outputs()[0].name],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: x_lengths.numpy(),
self.model.get_inputs()[2].name: noise_scale.numpy(),
self.model.get_inputs()[3].name: length_scale.numpy(),
},
)[0]
# mel: (batch_size, feat_dim, num_frames)
return torch.from_numpy(mel)
def read_tokens(filename: str) -> Dict[str, int]:
token2id = dict()
with open(filename, encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split()
if len(info) == 1:
# case of space
token = " "
idx = int(info[0])
else:
token, idx = info[0], int(info[1])
assert token not in token2id, token
token2id[token] = idx
return token2id
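# tokens.txt stores one "token id" pair per line; the space token leaves
# only the id visible after split(), hence the single-field case above.
# The first entries typically look like:
#   " " 0   (written as a line containing only the id)
#   _ 1
#   , 2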
def read_lexicon(filename: str) -> Dict[str, List[str]]:
word2token = dict()
with open(filename, encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split()
w = info[0]
tokens = info[1:]
word2token[w] = tokens
return word2token
def convert_word_to_tokens(word2tokens: Dict[str, List[str]], word: str) -> List[str]:
if word in word2tokens:
return word2tokens[word]
if len(word) == 1:
return []
ans = []
for w in word:
t = convert_word_to_tokens(word2tokens, w)
ans.extend(t)
return ans
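# A word that is not in the lexicon is split into single characters and
# looked up recursively; a single character that is still missing
# contributes no tokens and is silently dropped.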
def normalize_text(text):
whitespace_re = re.compile(r"\s+")
punctuations_re = [
(re.compile(x[0], re.IGNORECASE), x[1])
for x in [
(",", ","),
("。", "."),
("!", "!"),
("?", "?"),
("“", '"'),
("”", '"'),
("‘", "'"),
("’", "'"),
(":", ":"),
("、", ","),
]
]
for regex, replacement in punctuations_re:
text = re.sub(regex, replacement, text)
return text
@torch.no_grad()
def main():
params = get_parser().parse_args()
logging.info(vars(params))
token2id = read_tokens(params.tokens)
word2tokens = read_lexicon(params.lexicon)
text = normalize_text(params.input_text)
seg = jieba.cut(text)
tokens = []
for s in seg:
if s in token2id:
tokens.append(s)
continue
t = convert_word_to_tokens(word2tokens, s)
if t:
tokens.extend(t)
model = OnnxModel(params.acoustic_model)
vocoder = OnnxHifiGANModel(params.vocoder)
x = []
for t in tokens:
if t in token2id:
x.append(token2id[t])
x = intersperse(x, item=token2id["_"])
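# intersperse() (from utils.py) puts the padding token "_" between every
# two ids and at both ends, e.g. [a, b] -> [_, a, _, b, _], matching what
# the tokenizer does during training with intersperse_blank=True.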
x = torch.tensor(x, dtype=torch.int64).unsqueeze(0)
start_t = dt.datetime.now()
mel = model(x)
end_t = dt.datetime.now()
start_t2 = dt.datetime.now()
audio = vocoder(mel)
end_t2 = dt.datetime.now()
print("audio", audio.shape) # (1, 1, num_samples)
audio = audio.squeeze()
sample_rate = model.sample_rate
t = (end_t - start_t).total_seconds()
t2 = (end_t2 - start_t2).total_seconds()
rtf_am = t * sample_rate / audio.shape[-1]
rtf_vocoder = t2 * sample_rate / audio.shape[-1]
print("RTF for acoustic model ", rtf_am)
print("RTF for vocoder", rtf_vocoder)
# skip denoiser
sf.write(params.output_wav, audio, sample_rate, "PCM_16")
logging.info(f"Saved to {params.output_wav}")
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
"""
|HifiGAN |RTF |#Parameters (M)|
|----------|-----|---------------|
|v1 |0.818| 13.926 |
|v2 |0.101| 0.925 |
|v3 |0.118| 1.462 |
|Num steps|Acoustic Model RTF|
|---------|------------------|
| 2 | 0.039 |
| 3 | 0.047 |
| 4 | 0.071 |
| 5 | 0.076 |
| 6 | 0.103 |
"""

egs/baker_zh/TTS/matcha/tokenizer.py

@@ -0,0 +1,119 @@
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import logging
from typing import Dict, List
import tacotron_cleaner.cleaners
try:
from piper_phonemize import phonemize_espeak
except Exception as ex:
raise RuntimeError(
f"{ex}\nPlease run\n"
"pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html"
)
from utils import intersperse
# This tokenizer supports both English and Chinese.
# We assume you have used
# ../local/convert_text_to_tokens.py
# to process your text
class Tokenizer(object):
def __init__(self, tokens: str):
"""
Args:
tokens: the file that maps tokens to ids
"""
# Parse token file
self.token2id: Dict[str, int] = {}
with open(tokens, "r", encoding="utf-8") as f:
for line in f.readlines():
info = line.rstrip().split()
if len(info) == 1:
# case of space
token = " "
id = int(info[0])
else:
token, id = info[0], int(info[1])
assert token not in self.token2id, token
self.token2id[token] = id
# Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md
self.pad_id = self.token2id["_"] # padding
self.space_id = self.token2id[" "] # word separator (whitespace)
self.vocab_size = len(self.token2id)
def texts_to_token_ids(
self,
sentence_list: List[List[str]],
intersperse_blank: bool = True,
lang: str = "en-us",
) -> List[List[int]]:
"""
Args:
sentence_list:
A list of sentences.
intersperse_blank:
Whether to intersperse blanks in the token sequence.
lang:
Language argument passed to phonemize_espeak().
Returns:
Return a list of token id list [utterance][token_id]
"""
token_ids_list = []
for sentence in sentence_list:
tokens_list = []
for word in sentence:
if word in self.token2id:
tokens_list.append(word)
continue
tmp_tokens_list = phonemize_espeak(word, lang)
for t in tmp_tokens_list:
tokens_list.extend(t)
token_ids = []
for t in tokens_list:
if t not in self.token2id:
logging.warning(f"Skip OOV {t} {sentence}")
continue
if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id:
continue
token_ids.append(self.token2id[t])
if intersperse_blank:
token_ids = intersperse(token_ids, self.pad_id)
token_ids_list.append(token_ids)
return token_ids_list
def test_tokenizer():
import jieba
from pypinyin import Style, lazy_pinyin
tokenizer = Tokenizer("data/tokens.txt")
text1 = "今天is Monday, tomorrow is 星期二"
text2 = "你好吗? 我很好, how about you?"
text1 = list(jieba.cut(text1))
text2 = list(jieba.cut(text2))
tokens1 = lazy_pinyin(text1, style=Style.TONE3, tone_sandhi=True)
tokens2 = lazy_pinyin(text2, style=Style.TONE3, tone_sandhi=True)
print(tokens1)
print(tokens2)
ids = tokenizer.texts_to_token_ids([tokens1, tokens2])
print(ids)
if __name__ == "__main__":
test_tokenizer()

717
egs/baker_zh/TTS/matcha/train.py Executable file
@@ -0,0 +1,717 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import json
import logging
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Union
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lhotse.utils import fix_random_seed
from model import fix_len_compatibility
from models.matcha_tts import MatchaTTS
from tokenizer import Tokenizer
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import BakerZhTtsDataModule
from utils import MetricsTracker
from icefall.checkpoint import load_checkpoint, save_checkpoint
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12335,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=1000,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=Path,
default="matcha/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)
parser.add_argument(
"--cmvn",
type=str,
default="data/fbank/cmvn.json",
help="""Path to vocabulary.""",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--save-every-n",
type=int,
default=10,
help="""Save checkpoint after processing this number of epochs"
periodically. We save checkpoint to exp-dir/ whenever
params.cur_epoch % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
Since it will take around 1000 epochs, we suggest using a large
save_every_n to save disk space.
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
return parser
def get_data_statistics():
return AttributeDict(
{
"mel_mean": 0,
"mel_std": 1,
}
)
def _get_data_params() -> AttributeDict:
params = AttributeDict(
{
"name": "baker-zh",
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
# "batch_size": 64,
# "num_workers": 1,
# "pin_memory": False,
"cleaners": ["english_cleaners2"],
"add_blank": True,
"n_spks": 1,
"n_fft": 1024,
"n_feats": 80,
"sampling_rate": 22050,
"hop_length": 256,
"win_length": 1024,
"f_min": 0,
"f_max": 8000,
"seed": 1234,
"load_durations": False,
"data_statistics": get_data_statistics(),
}
)
return params
def _get_model_params() -> AttributeDict:
n_feats = 80
filter_channels_dp = 256
encoder_params_p_dropout = 0.1
params = AttributeDict(
{
"n_spks": 1, # for baker-zh.
"spk_emb_dim": 64,
"n_feats": n_feats,
"out_size": None, # or use 172
"prior_loss": True,
"use_precomputed_durations": False,
"data_statistics": get_data_statistics(),
"encoder": AttributeDict(
{
"encoder_type": "RoPE Encoder", # not used
"encoder_params": AttributeDict(
{
"n_feats": n_feats,
"n_channels": 192,
"filter_channels": 768,
"filter_channels_dp": filter_channels_dp,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": encoder_params_p_dropout,
"spk_emb_dim": 64,
"n_spks": 1,
"prenet": True,
}
),
"duration_predictor_params": AttributeDict(
{
"filter_channels_dp": filter_channels_dp,
"kernel_size": 3,
"p_dropout": encoder_params_p_dropout,
}
),
}
),
"decoder": AttributeDict(
{
"channels": [256, 256],
"dropout": 0.05,
"attention_head_dim": 64,
"n_blocks": 1,
"num_mid_blocks": 2,
"num_heads": 2,
"act_fn": "snakebeta",
}
),
"cfm": AttributeDict(
{
"name": "CFM",
"solver": "euler",
"sigma_min": 1e-4,
}
),
"optimizer": AttributeDict(
{
"lr": 1e-4,
"weight_decay": 0.0,
}
),
}
)
return params
def get_params():
params = AttributeDict(
{
"model_args": _get_model_params(),
"data_args": _get_data_params(),
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": -1, # 0
"log_interval": 10,
"valid_interval": 1500,
"env_info": get_env_info(),
}
)
return params
def get_model(params):
m = MatchaTTS(**params.model_args)
return m
def load_checkpoint_if_available(
params: AttributeDict, model: nn.Module
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model` and `optimizer` it also updates
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
Returns:
Return a dict containing previously saved training info.
"""
if params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
assert filename.is_file(), f"{filename} does not exist!"
saved_params = load_checkpoint(filename, model=model)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
"""Parse batch data"""
mel_mean = params.data_args.data_statistics.mel_mean
mel_std_inv = 1 / params.data_args.data_statistics.mel_std
for i in range(batch["features"].shape[0]):
n = batch["features_lens"][i]
batch["features"][i : i + 1, :n, :] = (
batch["features"][i : i + 1, :n, :] - mel_mean
) * mel_std_inv
batch["features"][i : i + 1, n:, :] = 0
audio = batch["audio"].to(device)
features = batch["features"].to(device)
audio_lens = batch["audio_lens"].to(device)
features_lens = batch["features_lens"].to(device)
tokens = batch["tokens"]
tokens = tokenizer.texts_to_token_ids(tokens, intersperse_blank=True)
tokens = k2.RaggedTensor(tokens)
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# a tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
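# fix_len_compatibility() (from Matcha-TTS) rounds the number of frames up
# so that the decoder's U-Net can downsample it evenly; pad the features
# with zeros to that length.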
max_feature_length = fix_len_compatibility(features.shape[1])
if max_feature_length > features.shape[1]:
pad = max_feature_length - features.shape[1]
features = torch.nn.functional.pad(features, (0, 0, 0, pad))
# features_lens[features_lens.argmax()] += pad
return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long()
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
rank: int = 0,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
# used to summary the stats over iterations
tot_loss = MetricsTracker()
with torch.no_grad():
for batch_idx, batch in enumerate(valid_dl):
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
) = prepare_input(batch, tokenizer, device, params)
losses = get_losses(
{
"x": tokens,
"x_lengths": tokens_lens,
"y": features.permute(0, 2, 1),
"y_lengths": features_lens,
"spks": None, # should change it for multi-speakers
"durations": None,
}
)
batch_size = len(batch["tokens"])
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
s = 0
for key, value in losses.items():
v = value.detach().item()
loss_info[key] = v * batch_size
s += v * batch_size
loss_info["tot_loss"] = s
# summary stats
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(device)
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
tokenizer: Tokenizer,
optimizer: Optimizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
tb_writer:
Writer to write log messages to tensorboard.
"""
model.train()
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
# used to track the stats over iterations in one epoch
tot_loss = MetricsTracker()
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint(
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
model=model,
params=params,
optimizer=optimizer,
scaler=scaler,
rank=0,
)
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
# audio: (N, T), float32
# features: (N, T, C), float32
# audio_lens, (N,), int32
# features_lens, (N,), int32
# tokens: List[List[str]], len(tokens) == N
batch_size = len(batch["tokens"])
(
audio,
audio_lens,
features,
features_lens,
tokens,
tokens_lens,
) = prepare_input(batch, tokenizer, device, params)
try:
with autocast(enabled=params.use_fp16):
losses = get_losses(
{
"x": tokens,
"x_lengths": tokens_lens,
"y": features.permute(0, 2, 1),
"y_lengths": features_lens,
"spks": None, # should change it for multi-speakers
"durations": None,
}
)
loss = sum(losses.values())
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
s = 0
for key, value in losses.items():
v = value.detach().item()
loss_info[key] = v * batch_size
s += v * batch_size
loss_info["tot_loss"] = s
tot_loss = tot_loss + loss_info
except: # noqa
save_bad_model()
raise
if params.batch_idx_train % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it.
# The _growth_interval of the grad scaler is configurable,
# but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (
cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise RuntimeError(
f"grad_scale is too small, exiting: {cur_grad_scale}"
)
if params.batch_idx_train % params.log_interval == 0:
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"global_batch_idx: {params.batch_idx_train}, "
f"batch size: {batch_size}, "
f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
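# "== 1" (rather than 0) makes validation also run on the very first training batch.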
if params.batch_idx_train % params.valid_interval == 1:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
tokenizer=tokenizer,
valid_dl=valid_dl,
world_size=world_size,
rank=rank,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(
"Maximum memory allocated so far is "
f"{torch.cuda.max_memory_allocated()//1000000}MB"
)
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
tokenizer = Tokenizer(params.tokens)
params.pad_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_args.n_vocab = params.vocab_size
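# cmvn.json stores the global fbank mean/std produced by
# local/compute_fbank_statistics.py in prepare.sh.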
with open(params.cmvn) as f:
stats = json.load(f)
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
params.data_args.data_statistics.mel_std = stats["fbank_std"]
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
params.model_args.data_statistics.mel_std = stats["fbank_std"]
logging.info(params)
logging.info("About to create model")
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of parameters: {num_param}")
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)
logging.info("About to create datamodule")
baker_zh = BakerZhTtsDataModule(args)
train_cuts = baker_zh.train_cuts()
train_dl = baker_zh.train_dataloaders(train_cuts)
valid_cuts = baker_zh.valid_cuts()
valid_dl = baker_zh.valid_dataloaders(valid_cuts)
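# Start the grad scaler at 1.0; train_one_epoch gradually doubles it
# while training remains stable.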
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
logging.info(f"Start epoch {epoch}")
fix_random_seed(params.seed + epoch - 1)
if "sampler" in train_dl:
train_dl.sampler.set_epoch(epoch - 1)
params.cur_epoch = epoch
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
train_one_epoch(
params=params,
model=model,
tokenizer=tokenizer,
optimizer=optimizer,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint(
filename=filename,
params=params,
model=model,
optimizer=optimizer,
scaler=scaler,
rank=rank,
)
if rank == 0:
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def main():
parser = get_parser()
BakerZhTtsDataModule.add_arguments(parser)
args = parser.parse_args()
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
main()


@ -0,0 +1,340 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
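# Used as worker_init_fn below: gives each dataloader worker a distinct,
# deterministic random seed.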
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class BakerZhTtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
train = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=True,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
else:
validate = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=True,
pin_memory=True,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
sampling_rate = 22050
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=sampling_rate,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
return_cuts=self.args.return_cuts,
)
else:
test = SpeechSynthesisDataset(
return_text=False,
return_tokens=True,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(
self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz"
)


@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/utils.py

151
egs/baker_zh/TTS/prepare.sh Executable file

@ -0,0 +1,151 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
stage=-1
stop_stage=100
dl_dir=$PWD/download
mkdir -p $dl_dir
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: build monotonic_align lib (used by ./matcha)"
for recipe in matcha; do
if [ ! -d $recipe/monotonic_align/build ]; then
cd $recipe/monotonic_align
python3 setup.py build_ext --inplace
cd ../../
else
log "monotonic_align lib for $recipe already built"
fi
done
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# The directory $dl_dir/BZNSYP contains the following 3 directories
# ls -lh $dl_dir/BZNSYP/
# total 0
# drwxr-xr-x 10002 kuangfangjun root 0 Jan 4 2019 PhoneLabeling
# drwxr-xr-x 3 kuangfangjun root 0 Jan 31 2019 ProsodyLabeling
# drwxr-xr-x 10003 kuangfangjun root 0 Aug 26 17:45 Wave
# If you have trouble accessing huggingface.co, please use
#
# cd $dl_dir
# wget https://huggingface.co/openspeech/BZNSYP/resolve/main/BZNSYP.tar.bz2
# tar xf BZNSYP.tar.bz2
# cd ..
# If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink
#
# ln -sfv /path/to/BZNSYP $dl_dir/BZNSYP
#
if [ ! -d $dl_dir/BZNSYP/Wave ]; then
lhotse download baker-zh $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare baker-zh manifest"
# We assume that you have downloaded the baker corpus
# to $dl_dir/BZNSYP
mkdir -p data/manifests
if [ ! -e data/manifests/.baker-zh.done ]; then
lhotse prepare baker-zh $dl_dir/BZNSYP data/manifests
touch data/manifests/.baker-zh.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Generate tokens.txt"
if [ ! -e data/tokens.txt ]; then
python3 ./local/generate_tokens.py --tokens data/tokens.txt
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Generate raw cutset"
if [ ! -e data/manifests/baker_zh_cuts_raw.jsonl.gz ]; then
lhotse cut simple \
-r ./data/manifests/baker_zh_recordings_all.jsonl.gz \
-s ./data/manifests/baker_zh_supervisions_all.jsonl.gz \
./data/manifests/baker_zh_cuts_raw.jsonl.gz
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Convert text to tokens"
if [ ! -e data/manifests/baker_zh_cuts.jsonl.gz ]; then
python3 ./local/convert_text_to_tokens.py \
--in-file ./data/manifests/baker_zh_cuts_raw.jsonl.gz \
--out-file ./data/manifests/baker_zh_cuts.jsonl.gz
fi
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate fbank (used by ./matcha)"
mkdir -p data/fbank
if [ ! -e data/fbank/.baker-zh.done ]; then
./local/compute_fbank_baker_zh.py
touch data/fbank/.baker-zh.done
fi
if [ ! -e data/fbank/.baker-zh-validated.done ]; then
log "Validating data/fbank for baker-zh (used by ./matcha)"
python3 ./local/validate_manifest.py \
data/fbank/baker_zh_cuts.jsonl.gz
touch data/fbank/.baker-zh-validated.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Split the baker-zh cuts into train, valid and test sets (used by ./matcha)"
if [ ! -e data/fbank/.baker_zh_split.done ]; then
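# Hold out the last 600 cuts: the first 100 of them become the valid set,
# the last 500 the test set; all remaining cuts form the train set.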
lhotse subset --last 600 \
data/fbank/baker_zh_cuts.jsonl.gz \
data/fbank/baker_zh_cuts_validtest.jsonl.gz
lhotse subset --first 100 \
data/fbank/baker_zh_cuts_validtest.jsonl.gz \
data/fbank/baker_zh_cuts_valid.jsonl.gz
lhotse subset --last 500 \
data/fbank/baker_zh_cuts_validtest.jsonl.gz \
data/fbank/baker_zh_cuts_test.jsonl.gz
rm data/fbank/baker_zh_cuts_validtest.jsonl.gz
n=$(( $(gunzip -c data/fbank/baker_zh_cuts.jsonl.gz | wc -l) - 600 ))
lhotse subset --first $n \
data/fbank/baker_zh_cuts.jsonl.gz \
data/fbank/baker_zh_cuts_train.jsonl.gz
touch data/fbank/.baker_zh_split.done
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 6: Compute fbank mean and std (used by ./matcha)"
if [ ! -f ./data/fbank/cmvn.json ]; then
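# cmvn.json (fbank_mean/fbank_std) is consumed by ./matcha/train.py via --cmvn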
./local/compute_fbank_statistics.py ./data/fbank/baker_zh_cuts_train.jsonl.gz ./data/fbank/cmvn.json
fi
fi

1
egs/baker_zh/TTS/shared Symbolic link

@ -0,0 +1 @@
../../../icefall/shared


@ -166,7 +166,7 @@ To export the checkpoint to onnx:
 --tokens ./data/tokens.txt
 ```
-The above command generate the following files:
+The above command generates the following files:
 - model-steps-2.onnx
 - model-steps-3.onnx


@ -93,14 +93,14 @@ class ModelWrapper(torch.nn.Module):
 self,
 x: torch.Tensor,
 x_lengths: torch.Tensor,
-temperature: torch.Tensor,
+noise_scale: torch.Tensor,
 length_scale: torch.Tensor,
 ) -> torch.Tensor:
 """
 Args:
 x: (batch_size, num_tokens), torch.int64
 x_lengths: (batch_size,), torch.int64
-temperature: (1,), torch.float32
+noise_scale: (1,), torch.float32
 length_scale: (1,), torch.float32
 Returns:
 audio: (batch_size, num_samples)
@ -110,7 +110,7 @@ class ModelWrapper(torch.nn.Module):
 x=x,
 x_lengths=x_lengths,
 n_timesteps=self.num_steps,
-temperature=temperature,
+temperature=noise_scale,
 length_scale=length_scale,
 )["mel"]
 # mel: (batch_size, feat_dim, num_frames)
@ -127,7 +127,6 @@ def main():
 params.update(vars(args))
 tokenizer = Tokenizer(params.tokens)
-params.blank_id = tokenizer.pad_id
 params.vocab_size = tokenizer.vocab_size
 params.model_args.n_vocab = params.vocab_size
@ -153,14 +152,14 @@ def main():
 # encoder has a large initial length
 x = torch.ones(1, 1000, dtype=torch.int64)
 x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
-temperature = torch.tensor([1.0])
+noise_scale = torch.tensor([1.0])
 length_scale = torch.tensor([1.0])
 opset_version = 14
 filename = f"model-steps-{num_steps}.onnx"
 torch.onnx.export(
 wrapper,
-(x, x_lengths, temperature, length_scale),
+(x, x_lengths, noise_scale, length_scale),
 filename,
 opset_version=opset_version,
 input_names=["x", "x_length", "noise_scale", "length_scale"],


@ -132,7 +132,7 @@ class OnnxModel:
 print("x_lengths", x_lengths)
 print("x", x.shape)
-temperature = torch.tensor([1.0], dtype=torch.float32)
+noise_scale = torch.tensor([1.0], dtype=torch.float32)
 length_scale = torch.tensor([1.0], dtype=torch.float32)
 mel = self.model.run(
@ -140,7 +140,7 @@ class OnnxModel:
 {
 self.model.get_inputs()[0].name: x.numpy(),
 self.model.get_inputs()[1].name: x_lengths.numpy(),
-self.model.get_inputs()[2].name: temperature.numpy(),
+self.model.get_inputs()[2].name: noise_scale.numpy(),
 self.model.get_inputs()[3].name: length_scale.numpy(),
 },
 )[0]
)[0]