Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 09:32:20 +00:00

Add Matcha-TTS (#1773)

Commit 516b4869b3 (parent 7e9eea6dc3)

.github/scripts/ljspeech/TTS/run-matcha.sh (vendored, new executable file, 120 lines)
@@ -0,0 +1,120 @@
#!/usr/bin/env bash

set -ex

apt-get update
apt-get install -y sox

python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html
python3 -m pip install espnet_tts_frontend
python3 -m pip install numba conformer==0.3.2 diffusers librosa

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

cd egs/ljspeech/TTS

sed -i.bak s/600/8/g ./prepare.sh
sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh
sed -i.bak s/500/5/g ./prepare.sh
git diff

function prepare_data() {
  # We have created a subset of the data for testing
  #
  mkdir -p download
  pushd download
  wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
  tar xvf LJSpeech-1.1.tar.bz2
  popd

  ./prepare.sh
  tree .
}

function train() {
  pushd ./matcha
  sed -i.bak s/1500/3/g ./train.py
  git diff .
  popd

  ./matcha/train.py \
    --exp-dir matcha/exp \
    --num-epochs 1 \
    --save-every-n 1 \
    --num-buckets 2 \
    --tokens data/tokens.txt \
    --max-duration 20

  ls -lh matcha/exp
}

function infer() {

  curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1

  ./matcha/inference.py \
    --epoch 1 \
    --exp-dir ./matcha/exp \
    --tokens data/tokens.txt \
    --vocoder ./generator_v1 \
    --input-text "how are you doing?" \
    --output-wav ./generated.wav

  ls -lh *.wav
  soxi ./generated.wav
  rm -v ./generated.wav
  rm -v generator_v1
}

function export_onnx() {
  pushd matcha/exp
  curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/epoch-4000.pt
  popd

  pushd data/fbank
  rm -v *.json
  curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/data/cmvn.json
  popd

  ./matcha/export_onnx.py \
    --exp-dir ./matcha/exp \
    --epoch 4000 \
    --tokens ./data/tokens.txt \
    --cmvn ./data/fbank/cmvn.json

  ls -lh *.onnx

  if false; then
    # The CI machine does not have enough memory to run it
    #
    curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
    curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
    curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3
    python3 ./matcha/export_onnx_hifigan.py
  else
    curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx
  fi

  ls -lh *.onnx

  python3 ./matcha/onnx_pretrained.py \
    --acoustic-model ./model-steps-6.onnx \
    --vocoder ./hifigan_v1.onnx \
    --tokens ./data/tokens.txt \
    --input-text "how are you doing?" \
    --output-wav /icefall/generated-matcha-tts-steps-6-v1.wav

  ls -lh /icefall/*.wav
  soxi /icefall/generated-matcha-tts-steps-6-v1.wav
}

prepare_data
train
infer
export_onnx

rm -rfv generator_v* matcha/exp
.github/scripts/ljspeech/TTS/run.sh (vendored, 2 changed lines)
@@ -22,7 +22,7 @@ git diff
 function prepare_data() {
   # We have created a subset of the data for testing
   #
-  mkdir download
+  mkdir -p download
   pushd download
   wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2
   tar xvf LJSpeech-1.1.tar.bz2
.github/workflows/audioset.yml (vendored, 6 changed lines)
@@ -83,7 +83,7 @@ jobs:
            ls -lh ./model-onnx/*

       - name: Upload model to huggingface
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN }}
         uses: nick-fields/retry@v3
@@ -116,7 +116,7 @@ jobs:
           rm -rf huggingface

       - name: Prepare for release
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         shell: bash
         run: |
           d=sherpa-onnx-zipformer-audio-tagging-2024-04-09
@@ -125,7 +125,7 @@ jobs:
           ls -lh

       - name: Release exported onnx models
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true
.github/workflows/ljspeech.yml (vendored, 12 changed lines)
@@ -70,6 +70,7 @@ jobs:
            cd /icefall
            git config --global --add safe.directory /icefall

+           .github/scripts/ljspeech/TTS/run-matcha.sh
            .github/scripts/ljspeech/TTS/run.sh

       - name: display files
@@ -78,19 +79,13 @@ jobs:
          ls -lh

       - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
+        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0'
         with:
           name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }}
           path: ./*.wav

-      - uses: actions/upload-artifact@v4
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0'
-        with:
-          name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}
-          path: ./*.wav
-
       - name: Release exported onnx models
-        if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push'
+        if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push'
         uses: svenstaro/upload-release-action@v2
         with:
           file_glob: true
@@ -99,4 +94,3 @@ jobs:
           repo_name: k2-fsa/sherpa-onnx
           repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
           tag: tts-models
-
egs/ljspeech/TTS/.gitignore (vendored, new file, 7 lines)
@@ -0,0 +1,7 @@
build
core.c
*.so
my-output*
*.wav
*.onnx
generator_v*
@@ -101,3 +101,121 @@ export CUDA_VISIBLE_DEVICES=4,5,6,7

# (Note it is killed after `epoch-820.pt`)
```

# matcha

[./matcha](./matcha) contains the code for training [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).

This recipe provides a Matcha-TTS model trained on the LJSpeech dataset.

Checkpoints and training logs can be found [here](https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28).
The pull request for this recipe can be found at <https://github.com/k2-fsa/icefall/pull/1773>.

The training command is given below:

```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3

python3 ./matcha/train.py \
  --exp-dir ./matcha/exp-new-3/ \
  --num-workers 4 \
  --world-size 4 \
  --num-epochs 4000 \
  --max-duration 1000 \
  --bucketing-sampler 1 \
  --start-epoch 1
```

To run inference, use:

```bash
# Download the HiFi-GAN vocoder. We use HiFi-GAN v1 below. You can select from v1, v2, or v3.

wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1

./matcha/inference.py \
  --exp-dir ./matcha/exp-new-3 \
  --epoch 4000 \
  --tokens ./data/tokens.txt \
  --vocoder ./generator_v1 \
  --input-text "how are you doing?" \
  --output-wav ./generated.wav
```

```bash
soxi ./generated.wav
```

prints:

```
Input File     : './generated.wav'
Channels       : 1
Sample Rate    : 22050
Precision      : 16-bit
Duration       : 00:00:01.29 = 28416 samples ~ 96.6531 CDDA sectors
File Size      : 56.9k
Bit Rate       : 353k
Sample Encoding: 16-bit Signed Integer PCM
```

To export the checkpoint to ONNX:

```bash
# Export the acoustic model to ONNX.

./matcha/export_onnx.py \
  --exp-dir ./matcha/exp-new-3 \
  --epoch 4000 \
  --tokens ./data/tokens.txt
```

The above command generates the following files:

- model-steps-2.onnx
- model-steps-3.onnx
- model-steps-4.onnx
- model-steps-5.onnx
- model-steps-6.onnx

where the 2 in `model-steps-2.onnx` means it uses 2 steps for the ODE solver.
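If you are unsure which ODE-step variant a given file corresponds to, `./matcha/export_onnx.py` writes this (together with the sample rate and other fields) into the ONNX metadata via its `add_meta_data` helper. The snippet below is a minimal sketch of how one could read that metadata back with `onnxruntime`; the file name `model-steps-6.onnx` is just an example, and `onnxruntime` is assumed to be installed.

```python
#!/usr/bin/env python3
# Minimal sketch: inspect the metadata written by ./matcha/export_onnx.py
# (keys such as "num_ode_steps" and "sample_rate") in an exported model.
import onnxruntime as ort

# Example file name; replace it with the model you want to inspect.
session = ort.InferenceSession(
    "model-steps-6.onnx", providers=["CPUExecutionProvider"]
)
meta = session.get_modelmeta().custom_metadata_map

print("model_type:", meta.get("model_type"))      # expected: matcha-tts
print("num_ode_steps:", meta.get("num_ode_steps"))
print("sample_rate:", meta.get("sample_rate"))    # expected: 22050
```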

To export the HiFi-GAN vocoder to ONNX, please use:

```bash
wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2
wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3

python3 ./matcha/export_onnx_hifigan.py
```

The above command generates 3 files:

- hifigan_v1.onnx
- hifigan_v2.onnx
- hifigan_v3.onnx

To use the generated ONNX files to generate speech from text, please run:

```bash
python3 ./matcha/onnx_pretrained.py \
  --acoustic-model ./model-steps-6.onnx \
  --vocoder ./hifigan_v1.onnx \
  --tokens ./data/tokens.txt \
  --input-text "Ask not what your country can do for you; ask what you can do for your country." \
  --output-wav ./matcha-epoch-4000-step6-hifigan-v1.wav
```

```bash
soxi ./matcha-epoch-4000-step6-hifigan-v1.wav

Input File     : './matcha-epoch-4000-step6-hifigan-v1.wav'
Channels       : 1
Sample Rate    : 22050
Precision      : 16-bit
Duration       : 00:00:05.46 = 120320 samples ~ 409.252 CDDA sectors
File Size      : 241k
Bit Rate       : 353k
Sample Encoding: 16-bit Signed Integer PCM
```

A sample of the generated audio: https://github.com/user-attachments/assets/b7c197a6-3870-49c6-90ca-db4d3776869b
egs/ljspeech/TTS/local/compute_fbank_ljspeech.py (new executable file, 208 lines)
@@ -0,0 +1,208 @@
#!/usr/bin/env python3
|
||||
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the LJSpeech dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
|
||||
from lhotse.audio import RecordingSet
|
||||
from lhotse.features.base import FeatureExtractor, register_extractor
|
||||
from lhotse.supervision import SupervisionSet
|
||||
from lhotse.utils import Seconds, compute_num_frames
|
||||
from matcha.audio import mel_spectrogram
|
||||
|
||||
from icefall.utils import get_executor
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyFbankConfig:
|
||||
n_fft: int
|
||||
n_mels: int
|
||||
sampling_rate: int
|
||||
hop_length: int
|
||||
win_length: int
|
||||
f_min: float
|
||||
f_max: float
|
||||
|
||||
|
||||
@register_extractor
|
||||
class MyFbank(FeatureExtractor):
|
||||
|
||||
name = "MyFbank"
|
||||
config_type = MyFbankConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config=config)
|
||||
|
||||
@property
|
||||
def device(self) -> Union[str, torch.device]:
|
||||
return self.config.device
|
||||
|
||||
def feature_dim(self, sampling_rate: int) -> int:
|
||||
return self.config.n_mels
|
||||
|
||||
def extract(
|
||||
self,
|
||||
samples: np.ndarray,
|
||||
sampling_rate: int,
|
||||
) -> torch.Tensor:
|
||||
# Check for sampling rate compatibility.
|
||||
expected_sr = self.config.sampling_rate
|
||||
assert sampling_rate == expected_sr, (
|
||||
f"Mismatched sampling rate: extractor expects {expected_sr}, "
|
||||
f"got {sampling_rate}"
|
||||
)
|
||||
samples = torch.from_numpy(samples)
|
||||
assert samples.ndim == 2, samples.shape
|
||||
assert samples.shape[0] == 1, samples.shape
|
||||
|
||||
mel = (
|
||||
mel_spectrogram(
|
||||
samples,
|
||||
self.config.n_fft,
|
||||
self.config.n_mels,
|
||||
self.config.sampling_rate,
|
||||
self.config.hop_length,
|
||||
self.config.win_length,
|
||||
self.config.f_min,
|
||||
self.config.f_max,
|
||||
center=False,
|
||||
)
|
||||
.squeeze()
|
||||
.t()
|
||||
)
|
||||
|
||||
assert mel.ndim == 2, mel.shape
|
||||
assert mel.shape[1] == self.config.n_mels, mel.shape
|
||||
|
||||
num_frames = compute_num_frames(
|
||||
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
|
||||
)
|
||||
|
||||
if mel.shape[0] > num_frames:
|
||||
mel = mel[:num_frames]
|
||||
elif mel.shape[0] < num_frames:
|
||||
mel = mel.unsqueeze(0)
|
||||
mel = torch.nn.functional.pad(
|
||||
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
|
||||
).squeeze(0)
|
||||
|
||||
return mel.numpy()
|
||||
|
||||
@property
|
||||
def frame_shift(self) -> Seconds:
|
||||
return self.config.hop_length / self.config.sampling_rate
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-jobs",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
""",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def compute_fbank_ljspeech(num_jobs: int):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
if num_jobs < 1:
|
||||
num_jobs = os.cpu_count()
|
||||
|
||||
logging.info(f"num_jobs: {num_jobs}")
|
||||
logging.info(f"src_dir: {src_dir}")
|
||||
logging.info(f"output_dir: {output_dir}")
|
||||
config = MyFbankConfig(
|
||||
n_fft=1024,
|
||||
n_mels=80,
|
||||
sampling_rate=22050,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
)
|
||||
|
||||
prefix = "ljspeech"
|
||||
suffix = "jsonl.gz"
|
||||
partition = "all"
|
||||
|
||||
recordings = load_manifest(
|
||||
src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
|
||||
)
|
||||
supervisions = load_manifest(
|
||||
src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
|
||||
)
|
||||
|
||||
extractor = MyFbank(config)
|
||||
|
||||
with get_executor() as ex: # Initialize the executor only once.
|
||||
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
||||
if (output_dir / cuts_filename).is_file():
|
||||
logging.info(f"{cuts_filename} already exists - skipping.")
|
||||
return
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=recordings, supervisions=supervisions
|
||||
)
|
||||
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||
# when an executor is specified, make more partitions
|
||||
num_jobs=num_jobs if ex is None else 80,
|
||||
executor=ex,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
cut_set.to_file(output_dir / cuts_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
args = get_parser().parse_args()
|
||||
compute_fbank_ljspeech(args.num_jobs)
|
egs/ljspeech/TTS/local/compute_fbank_statistics.py (new executable file, 84 lines)
@@ -0,0 +1,84 @@
#!/usr/bin/env python3
# Copyright    2024  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script computes the mean and std of the fbank features.
"""

import argparse
import json
import logging
from pathlib import Path

import torch
from lhotse import CutSet, load_manifest_lazy


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "manifest",
        type=Path,
        help="Path to the manifest file",
    )

    parser.add_argument(
        "cmvn",
        type=Path,
        help="Path to the cmvn.json",
    )

    return parser.parse_args()


def main():
    args = get_args()

    manifest = args.manifest
    logging.info(
        f"Computing fbank mean and std for {manifest} and saving to {args.cmvn}"
    )

    assert manifest.is_file(), f"{manifest} does not exist"
    cut_set = load_manifest_lazy(manifest)
    assert isinstance(cut_set, CutSet), type(cut_set)

    feat_dim = cut_set[0].features.num_features
    num_frames = 0
    s = 0
    sq = 0
    for c in cut_set:
        f = torch.from_numpy(c.load_features())
        num_frames += f.shape[0]
        s += f.sum()
        sq += f.square().sum()

    fbank_mean = s / (num_frames * feat_dim)
    fbank_var = sq / (num_frames * feat_dim) - fbank_mean * fbank_mean
    print("fbank var", fbank_var)
    fbank_std = fbank_var.sqrt()
    with open(args.cmvn, "w") as f:
        json.dump({"fbank_mean": fbank_mean.item(), "fbank_std": fbank_std.item()}, f)
        f.write("\n")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
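For reference, the values this script writes to `cmvn.json` are the per-element mean and standard deviation over all frames and all mel bins; this is just a restatement of the accumulation loop above, with $N$ the total number of frames and $D$ the feature dimension:

$$
\mu = \frac{1}{N D}\sum_{t=1}^{N}\sum_{d=1}^{D} x_{t,d},
\qquad
\sigma = \sqrt{\frac{1}{N D}\sum_{t=1}^{N}\sum_{d=1}^{D} x_{t,d}^{2} \;-\; \mu^{2}}
$$

These statistics are later read back by `./matcha/export_onnx.py` as `fbank_mean` and `fbank_std`.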
@@ -28,17 +28,33 @@ try:
 except ModuleNotFoundError as ex:
     raise RuntimeError(f"{ex}\nPlease run\n  pip install espnet_tts_frontend\n")

+import argparse
+
 from lhotse import CutSet, load_manifest
 from piper_phonemize import phonemize_espeak


-def prepare_tokens_ljspeech():
-    output_dir = Path("data/spectrogram")
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--in-out-dir",
+        type=Path,
+        required=True,
+        help="Input and output directory",
+    )
+
+    return parser
+
+
+def prepare_tokens_ljspeech(in_out_dir):
     prefix = "ljspeech"
     suffix = "jsonl.gz"
     partition = "all"

-    cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+    cut_set = load_manifest(in_out_dir / f"{prefix}_cuts_{partition}.{suffix}")

     new_cuts = []
     for cut in cut_set:
@@ -56,11 +72,13 @@ def prepare_tokens_ljspeech():
         new_cuts.append(cut)

     new_cut_set = CutSet.from_cuts(new_cuts)
-    new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")
+    new_cut_set.to_file(in_out_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")


 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)

-    prepare_tokens_ljspeech()
+    args = get_parser().parse_args()
+
+    prepare_tokens_ljspeech(args.in_out_dir)
@@ -33,6 +33,7 @@ import argparse
 import logging
 from pathlib import Path

+from compute_fbank_ljspeech import MyFbank
 from lhotse import CutSet, load_manifest_lazy
 from lhotse.dataset.speech_synthesis import validate_for_tts

egs/ljspeech/TTS/matcha/LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Shivam Mehta

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

egs/ljspeech/TTS/matcha/__init__.py (new empty file)

egs/ljspeech/TTS/matcha/audio.py (new file, 92 lines)
@@ -0,0 +1,92 @@
# This file is copied from
|
||||
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from scipy.io.wavfile import read
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
def load_wav(full_path):
|
||||
sampling_rate, data = read(full_path)
|
||||
return data, sampling_rate
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
return np.exp(x) / C
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes):
|
||||
output = dynamic_range_compression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(magnitudes):
|
||||
output = dynamic_range_decompression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
mel_basis = {}
|
||||
hann_window = {}
|
||||
|
||||
|
||||
def mel_spectrogram(
|
||||
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
|
||||
):
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
print("max value is ", torch.max(y))
|
||||
|
||||
global mel_basis, hann_window # pylint: disable=global-statement
|
||||
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel_basis[str(fmax) + "_" + str(y.device)] = (
|
||||
torch.from_numpy(mel).float().to(y.device)
|
||||
)
|
||||
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
mode="reflect",
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.view_as_real(
|
||||
torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window[str(y.device)],
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
||||
|
||||
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
|
||||
return spec
|
egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py (new symbolic link)
@@ -0,0 +1 @@
../local/compute_fbank_ljspeech.py
egs/ljspeech/TTS/matcha/export_onnx.py (new executable file, 196 lines)
@@ -0,0 +1,196 @@
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script exports a Matcha-TTS model to ONNX.
|
||||
Note that the model outputs fbank. You need to use a vocoder to convert
|
||||
it to audio. See also ./export_onnx_hifigan.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from tokenizer import Tokenizer
|
||||
from train import get_model, get_params
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=4000,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=Path,
|
||||
default="matcha/exp-new-3",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=Path,
|
||||
default="data/tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cmvn",
|
||||
type=str,
|
||||
default="data/fbank/cmvn.json",
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
|
||||
while len(model.metadata_props):
|
||||
model.metadata_props.pop()
|
||||
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = str(value)
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
class ModelWrapper(torch.nn.Module):
|
||||
def __init__(self, model, num_steps: int = 5):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.num_steps = num_steps
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lengths: torch.Tensor,
|
||||
temperature: torch.Tensor,
|
||||
length_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (batch_size, num_tokens), torch.int64
|
||||
x_lengths: (batch_size,), torch.int64
|
||||
temperature: (1,), torch.float32
|
||||
length_scale (1,), torch.float32
|
||||
Returns:
|
||||
mel: (batch_size, feat_dim, num_frames), torch.float32
|
||||
|
||||
"""
|
||||
mel = self.model.synthesise(
|
||||
x=x,
|
||||
x_lengths=x_lengths,
|
||||
n_timesteps=self.num_steps,
|
||||
temperature=temperature,
|
||||
length_scale=length_scale,
|
||||
)["mel"]
|
||||
# mel: (batch_size, feat_dim, num_frames)
|
||||
|
||||
return mel
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
tokenizer = Tokenizer(params.tokens)
|
||||
params.blank_id = tokenizer.pad_id
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
params.model_args.n_vocab = params.vocab_size
|
||||
|
||||
with open(params.cmvn) as f:
|
||||
stats = json.load(f)
|
||||
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.data_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.model_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
|
||||
for num_steps in [2, 3, 4, 5, 6]:
|
||||
logging.info(f"num_steps: {num_steps}")
|
||||
wrapper = ModelWrapper(model, num_steps=num_steps)
|
||||
wrapper.eval()
|
||||
|
||||
# Use a large value so the rotary position embedding in the text
|
||||
# encoder has a large initial length
|
||||
x = torch.ones(1, 1000, dtype=torch.int64)
|
||||
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
|
||||
temperature = torch.tensor([1.0])
|
||||
length_scale = torch.tensor([1.0])
|
||||
|
||||
opset_version = 14
|
||||
filename = f"model-steps-{num_steps}.onnx"
|
||||
torch.onnx.export(
|
||||
wrapper,
|
||||
(x, x_lengths, temperature, length_scale),
|
||||
filename,
|
||||
opset_version=opset_version,
|
||||
input_names=["x", "x_length", "temperature", "length_scale"],
|
||||
output_names=["mel"],
|
||||
dynamic_axes={
|
||||
"x": {0: "N", 1: "L"},
|
||||
"x_length": {0: "N"},
|
||||
"mel": {0: "N", 2: "L"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"model_type": "matcha-tts",
|
||||
"language": "English",
|
||||
"voice": "en-us",
|
||||
"has_espeak": 1,
|
||||
"n_speakers": 1,
|
||||
"sample_rate": 22050,
|
||||
"version": 1,
|
||||
"model_author": "icefall",
|
||||
"maintainer": "k2-fsa",
|
||||
"dataset": "LJ Speech",
|
||||
"num_ode_steps": num_steps,
|
||||
}
|
||||
add_meta_data(filename=filename, meta_data=meta_data)
|
||||
print(meta_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
egs/ljspeech/TTS/matcha/export_onnx_hifigan.py (new executable file, 110 lines)
@@ -0,0 +1,110 @@
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from inference import load_vocoder
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
|
||||
while len(model.metadata_props):
|
||||
model.metadata_props.pop()
|
||||
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = str(value)
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
class ModelWrapper(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
mel: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
mel: (batch_size, feat_dim, num_frames), torch.float32
|
||||
Returns:
|
||||
audio: (batch_size, num_samples), torch.float32
|
||||
"""
|
||||
audio = self.model(mel).clamp(-1, 1).squeeze(1)
|
||||
return audio
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main():
|
||||
# Please go to
|
||||
# https://github.com/csukuangfj/models/tree/master/hifigan
|
||||
# to download the following files
|
||||
model_filenames = ["./generator_v1", "./generator_v2", "./generator_v3"]
|
||||
|
||||
for f in model_filenames:
|
||||
logging.info(f)
|
||||
if not Path(f).is_file():
|
||||
logging.info(f"Skipping {f} since {f} does not exist")
|
||||
continue
|
||||
model = load_vocoder(f)
|
||||
wrapper = ModelWrapper(model)
|
||||
wrapper.eval()
|
||||
num_param = sum([p.numel() for p in wrapper.parameters()])
|
||||
logging.info(f"{f}: Number of parameters: {num_param}")
|
||||
|
||||
# Use a large number of frames so the exported vocoder
# supports long mel inputs
|
||||
x = torch.ones(1, 80, 100000, dtype=torch.float32)
|
||||
opset_version = 14
|
||||
suffix = f.split("_")[-1]
|
||||
filename = f"hifigan_{suffix}.onnx"
|
||||
torch.onnx.export(
|
||||
wrapper,
|
||||
x,
|
||||
filename,
|
||||
opset_version=opset_version,
|
||||
input_names=["mel"],
|
||||
output_names=["audio"],
|
||||
dynamic_axes={
|
||||
"mel": {0: "N", 2: "L"},
|
||||
"audio": {0: "N", 1: "L"},
|
||||
},
|
||||
)
|
||||
|
||||
meta_data = {
|
||||
"model_type": "hifigan",
|
||||
"model_filename": f.split("/")[-1],
|
||||
"sample_rate": 22050,
|
||||
"version": 1,
|
||||
"model_author": "jik876",
|
||||
"maintainer": "k2-fsa",
|
||||
"dataset": "LJ Speech",
|
||||
"url1": "https://github.com/jik876/hifi-gan",
|
||||
"url2": "https://github.com/csukuangfj/models/tree/master/hifigan",
|
||||
}
|
||||
add_meta_data(filename=filename, meta_data=meta_data)
|
||||
print(meta_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
egs/ljspeech/TTS/matcha/hifigan/LICENSE (new file, 21 lines)
@@ -0,0 +1,21 @@
MIT License
|
||||
|
||||
Copyright (c) 2020 Jungil Kong
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
egs/ljspeech/TTS/matcha/hifigan/README.md (new file, 101 lines)
@@ -0,0 +1,101 @@
# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
|
||||
|
||||
### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae
|
||||
|
||||
In our [paper](https://arxiv.org/abs/2010.05646),
|
||||
we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.<br/>
|
||||
We provide our implementation and pretrained models as open source in this repository.
|
||||
|
||||
**Abstract :**
|
||||
Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
|
||||
Although such methods improve the sampling efficiency and memory usage,
|
||||
their sample quality has not yet reached that of autoregressive and flow-based generative models.
|
||||
In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
|
||||
As speech audio consists of sinusoidal signals with various periods,
|
||||
we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality.
|
||||
A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
|
||||
demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
|
||||
real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
|
||||
speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
|
||||
faster than real-time on CPU with comparable quality to an autoregressive counterpart.
|
||||
|
||||
Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.
|
||||
|
||||
## Pre-requisites
|
||||
|
||||
1. Python >= 3.6
|
||||
2. Clone this repository.
|
||||
3. Install python requirements. Please refer [requirements.txt](requirements.txt)
|
||||
4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
|
||||
And move all wav files to `LJSpeech-1.1/wavs`
|
||||
|
||||
## Training
|
||||
|
||||
```
|
||||
python train.py --config config_v1.json
|
||||
```
|
||||
|
||||
To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.<br>
|
||||
Checkpoints and copy of the configuration file are saved in `cp_hifigan` directory by default.<br>
|
||||
You can change the path by adding `--checkpoint_path` option.
|
||||
|
||||
Validation loss during training with V1 generator.<br>
|
||||

|
||||
|
||||
## Pretrained Model
|
||||
|
||||
You can also use pretrained models we provide.<br/>
|
||||
[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)<br/>
|
||||
Details of each folder are as in follows:
|
||||
|
||||
| Folder Name | Generator | Dataset | Fine-Tuned |
|
||||
| ------------ | --------- | --------- | ------------------------------------------------------ |
|
||||
| LJ_V1 | V1 | LJSpeech | No |
|
||||
| LJ_V2 | V2 | LJSpeech | No |
|
||||
| LJ_V3 | V3 | LJSpeech | No |
|
||||
| LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
|
||||
| LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
|
||||
| LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
|
||||
| VCTK_V1 | V1 | VCTK | No |
|
||||
| VCTK_V2 | V2 | VCTK | No |
|
||||
| VCTK_V3 | V3 | VCTK | No |
|
||||
| UNIVERSAL_V1 | V1 | Universal | No |
|
||||
|
||||
We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
|
||||
|
||||
## Fine-Tuning
|
||||
|
||||
1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.<br/>
|
||||
The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.<br/>
|
||||
Example:
|
||||
` Audio File : LJ001-0001.wav
|
||||
Mel-Spectrogram File : LJ001-0001.npy`
|
||||
2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.<br/>
|
||||
3. Run the following command.
|
||||
```
|
||||
python train.py --fine_tuning True --config config_v1.json
|
||||
```
|
||||
For other command line options, please refer to the training section.
|
||||
|
||||
## Inference from wav file
|
||||
|
||||
1. Make `test_files` directory and copy wav files into the directory.
|
||||
2. Run the following command.
|
||||
` python inference.py --checkpoint_file [generator checkpoint file path]`
|
||||
Generated wav files are saved in `generated_files` by default.<br>
|
||||
You can change the path by adding `--output_dir` option.
|
||||
|
||||
## Inference for end-to-end speech synthesis
|
||||
|
||||
1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.<br>
|
||||
You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
|
||||
[Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
|
||||
2. Run the following command.
|
||||
` python inference_e2e.py --checkpoint_file [generator checkpoint file path]`
|
||||
Generated wav files are saved in `generated_files_from_mel` by default.<br>
|
||||
You can change the path by adding `--output_dir` option.
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
|
||||
and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.
|
egs/ljspeech/TTS/matcha/hifigan/__init__.py (new empty file)

egs/ljspeech/TTS/matcha/hifigan/config.py (new file, 100 lines)
@@ -0,0 +1,100 @@
v1 = {
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0004,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.999,
|
||||
"seed": 1234,
|
||||
"upsample_rates": [8, 8, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 16, 4, 4],
|
||||
"upsample_initial_channel": 512,
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"resblock_initial_channel": 256,
|
||||
"segment_size": 8192,
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 256,
|
||||
"win_size": 1024,
|
||||
"sampling_rate": 22050,
|
||||
"fmin": 0,
|
||||
"fmax": 8000,
|
||||
"fmax_loss": None,
|
||||
"num_workers": 4,
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1,
|
||||
},
|
||||
}
|
||||
|
||||
# See https://drive.google.com/drive/folders/1bB1tnGIxRN-edlf6k2Rmi1gNCK9Cpcvf
|
||||
v2 = {
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0002,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.999,
|
||||
"seed": 1234,
|
||||
"upsample_rates": [8, 8, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 16, 4, 4],
|
||||
"upsample_initial_channel": 128,
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"resblock_initial_channel": 64,
|
||||
"segment_size": 8192,
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 256,
|
||||
"win_size": 1024,
|
||||
"sampling_rate": 22050,
|
||||
"fmin": 0,
|
||||
"fmax": 8000,
|
||||
"fmax_loss": None,
|
||||
"num_workers": 4,
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1,
|
||||
},
|
||||
}
|
||||
|
||||
# See https://drive.google.com/drive/folders/1KKvuJTLp_gZXC8lug7H_lSXct38_3kx1
|
||||
v3 = {
|
||||
"resblock": "2",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 16,
|
||||
"learning_rate": 0.0002,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.999,
|
||||
"seed": 1234,
|
||||
"upsample_rates": [8, 8, 4],
|
||||
"upsample_kernel_sizes": [16, 16, 8],
|
||||
"upsample_initial_channel": 256,
|
||||
"resblock_kernel_sizes": [3, 5, 7],
|
||||
"resblock_dilation_sizes": [[1, 2], [2, 6], [3, 12]],
|
||||
"resblock_initial_channel": 128,
|
||||
"segment_size": 8192,
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 256,
|
||||
"win_size": 1024,
|
||||
"sampling_rate": 22050,
|
||||
"fmin": 0,
|
||||
"fmax": 8000,
|
||||
"fmax_loss": None,
|
||||
"num_workers": 4,
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1,
|
||||
},
|
||||
}
|
egs/ljspeech/TTS/matcha/hifigan/denoiser.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py
|
||||
|
||||
"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
|
||||
import torch
|
||||
|
||||
|
||||
class Denoiser(torch.nn.Module):
|
||||
"""Removes model bias from audio produced with waveglow"""
|
||||
|
||||
def __init__(
|
||||
self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"
|
||||
):
|
||||
super().__init__()
|
||||
self.filter_length = filter_length
|
||||
self.hop_length = int(filter_length / n_overlap)
|
||||
self.win_length = win_length
|
||||
|
||||
dtype, device = (
|
||||
next(vocoder.parameters()).dtype,
|
||||
next(vocoder.parameters()).device,
|
||||
)
|
||||
self.device = device
|
||||
if mode == "zeros":
|
||||
mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
|
||||
elif mode == "normal":
|
||||
mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
|
||||
else:
|
||||
raise Exception(f"Mode {mode} if not supported")
|
||||
|
||||
def stft_fn(audio, n_fft, hop_length, win_length, window):
|
||||
spec = torch.stft(
|
||||
audio,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
return_complex=True,
|
||||
)
|
||||
spec = torch.view_as_real(spec)
|
||||
return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(
|
||||
spec[..., -1], spec[..., 0]
|
||||
)
|
||||
|
||||
self.stft = lambda x: stft_fn(
|
||||
audio=x,
|
||||
n_fft=self.filter_length,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=torch.hann_window(self.win_length, device=device),
|
||||
)
|
||||
self.istft = lambda x, y: torch.istft(
|
||||
torch.complex(x * torch.cos(y), x * torch.sin(y)),
|
||||
n_fft=self.filter_length,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=torch.hann_window(self.win_length, device=device),
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
bias_audio = vocoder(mel_input).float().squeeze(0)
|
||||
bias_spec, _ = self.stft(bias_audio)
|
||||
|
||||
self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, audio, strength=0.0005):
|
||||
audio_spec, audio_angles = self.stft(audio)
|
||||
audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength
|
||||
audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
|
||||
audio_denoised = self.istft(audio_spec_denoised, audio_angles)
|
||||
return audio_denoised
|
egs/ljspeech/TTS/matcha/hifigan/env.py (new file, 17 lines)
@@ -0,0 +1,17 @@
""" from https://github.com/jik876/hifi-gan """
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def build_env(config, config_name, path):
|
||||
t_path = os.path.join(path, config_name)
|
||||
if config != t_path:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
shutil.copyfile(config, os.path.join(path, config_name))
|
egs/ljspeech/TTS/matcha/hifigan/meldataset.py (new file, 245 lines)
@@ -0,0 +1,245 @@
""" from https://github.com/jik876/hifi-gan """
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from librosa.util import normalize
|
||||
from scipy.io.wavfile import read
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
def load_wav(full_path):
|
||||
sampling_rate, data = read(full_path)
|
||||
return data, sampling_rate
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
return np.exp(x) / C
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes):
|
||||
output = dynamic_range_compression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(magnitudes):
|
||||
output = dynamic_range_decompression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
|
||||
mel_basis = {}
|
||||
hann_window = {}
|
||||
|
||||
|
||||
def mel_spectrogram(
|
||||
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
|
||||
):
|
||||
if torch.min(y) < -1.0:
|
||||
print("min value is ", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
print("max value is ", torch.max(y))
|
||||
|
||||
global mel_basis, hann_window # pylint: disable=global-statement
|
||||
    if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
|
||||
mel_basis[str(fmax) + "_" + str(y.device)] = (
|
||||
torch.from_numpy(mel).float().to(y.device)
|
||||
)
|
||||
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
mode="reflect",
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.view_as_real(
|
||||
torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window[str(y.device)],
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
||||
|
||||
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
def get_dataset_filelist(a):
|
||||
with open(a.input_training_file, encoding="utf-8") as fi:
|
||||
training_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
]
|
||||
|
||||
with open(a.input_validation_file, encoding="utf-8") as fi:
|
||||
validation_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
]
|
||||
return training_files, validation_files
|
||||
|
||||
|
||||
class MelDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
training_files,
|
||||
segment_size,
|
||||
n_fft,
|
||||
num_mels,
|
||||
hop_size,
|
||||
win_size,
|
||||
sampling_rate,
|
||||
fmin,
|
||||
fmax,
|
||||
split=True,
|
||||
shuffle=True,
|
||||
n_cache_reuse=1,
|
||||
device=None,
|
||||
fmax_loss=None,
|
||||
fine_tuning=False,
|
||||
base_mels_path=None,
|
||||
):
|
||||
self.audio_files = training_files
|
||||
random.seed(1234)
|
||||
if shuffle:
|
||||
random.shuffle(self.audio_files)
|
||||
self.segment_size = segment_size
|
||||
self.sampling_rate = sampling_rate
|
||||
self.split = split
|
||||
self.n_fft = n_fft
|
||||
self.num_mels = num_mels
|
||||
self.hop_size = hop_size
|
||||
self.win_size = win_size
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.fmax_loss = fmax_loss
|
||||
self.cached_wav = None
|
||||
self.n_cache_reuse = n_cache_reuse
|
||||
self._cache_ref_count = 0
|
||||
self.device = device
|
||||
self.fine_tuning = fine_tuning
|
||||
self.base_mels_path = base_mels_path
|
||||
|
||||
def __getitem__(self, index):
|
||||
filename = self.audio_files[index]
|
||||
if self._cache_ref_count == 0:
|
||||
audio, sampling_rate = load_wav(filename)
|
||||
audio = audio / MAX_WAV_VALUE
|
||||
if not self.fine_tuning:
|
||||
audio = normalize(audio) * 0.95
|
||||
self.cached_wav = audio
|
||||
if sampling_rate != self.sampling_rate:
|
||||
raise ValueError(
|
||||
f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR"
|
||||
)
|
||||
self._cache_ref_count = self.n_cache_reuse
|
||||
else:
|
||||
audio = self.cached_wav
|
||||
self._cache_ref_count -= 1
|
||||
|
||||
audio = torch.FloatTensor(audio)
|
||||
audio = audio.unsqueeze(0)
|
||||
|
||||
if not self.fine_tuning:
|
||||
if self.split:
|
||||
if audio.size(1) >= self.segment_size:
|
||||
max_audio_start = audio.size(1) - self.segment_size
|
||||
audio_start = random.randint(0, max_audio_start)
|
||||
audio = audio[:, audio_start : audio_start + self.segment_size]
|
||||
else:
|
||||
audio = torch.nn.functional.pad(
|
||||
audio, (0, self.segment_size - audio.size(1)), "constant"
|
||||
)
|
||||
|
||||
mel = mel_spectrogram(
|
||||
audio,
|
||||
self.n_fft,
|
||||
self.num_mels,
|
||||
self.sampling_rate,
|
||||
self.hop_size,
|
||||
self.win_size,
|
||||
self.fmin,
|
||||
self.fmax,
|
||||
center=False,
|
||||
)
|
||||
else:
|
||||
mel = np.load(
|
||||
os.path.join(
|
||||
self.base_mels_path,
|
||||
os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
|
||||
)
|
||||
)
|
||||
mel = torch.from_numpy(mel)
|
||||
|
||||
if len(mel.shape) < 3:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
if self.split:
|
||||
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
|
||||
|
||||
if audio.size(1) >= self.segment_size:
|
||||
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
|
||||
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
|
||||
audio = audio[
|
||||
:,
|
||||
mel_start
|
||||
* self.hop_size : (mel_start + frames_per_seg)
|
||||
* self.hop_size,
|
||||
]
|
||||
else:
|
||||
mel = torch.nn.functional.pad(
|
||||
mel, (0, frames_per_seg - mel.size(2)), "constant"
|
||||
)
|
||||
audio = torch.nn.functional.pad(
|
||||
audio, (0, self.segment_size - audio.size(1)), "constant"
|
||||
)
|
||||
|
||||
mel_loss = mel_spectrogram(
|
||||
audio,
|
||||
self.n_fft,
|
||||
self.num_mels,
|
||||
self.sampling_rate,
|
||||
self.hop_size,
|
||||
self.win_size,
|
||||
self.fmin,
|
||||
self.fmax_loss,
|
||||
center=False,
|
||||
)
|
||||
|
||||
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
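        # Clarifying note (not in the original file): `mel` is computed with
        # fmax and is what the generator is conditioned on, while `mel_loss` is
        # computed with fmax_loss (typically None, i.e. full band) and is the
        # target used for the mel-spectrogram loss during training.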
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audio_files)
|
egs/ljspeech/TTS/matcha/hifigan/models.py (new file, 406 lines)
@ -0,0 +1,406 @@
|
||||
""" from https://github.com/jik876/hifi-gan """
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||
|
||||
from .xutils import get_padding, init_weights
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super().__init__()
|
||||
self.h = h
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.convs1.apply(init_weights)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super().__init__()
|
||||
self.h = h
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
)
|
||||
),
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.convs.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(self, h):
|
||||
super().__init__()
|
||||
self.h = h
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
self.conv_pre = weight_norm(
|
||||
Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
|
||||
)
|
||||
resblock = ResBlock1 if h.resblock == "1" else ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(
|
||||
h.upsample_initial_channel // (2**i),
|
||||
h.upsample_initial_channel // (2 ** (i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(
|
||||
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
|
||||
):
|
||||
self.resblocks.append(resblock(h, ch, k, d))
|
||||
|
||||
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
||||
self.ups.apply(init_weights)
|
||||
self.conv_post.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_pre(x)
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print("Removing weight norm...")
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
self.period = period
|
||||
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(
|
||||
Conv2d(
|
||||
1,
|
||||
32,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
32,
|
||||
128,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
128,
|
||||
512,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
512,
|
||||
1024,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
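        # Illustrative note (not in the original file): e.g. with t = 22050 and
        # period = 11, the waveform is reflect-padded by 5 samples to t = 22055
        # and reshaped to (b, c, 2005, 11), so the (kernel_size, 1) 2-D
        # convolutions below compare samples that are exactly one period apart.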
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorP(2),
|
||||
DiscriminatorP(3),
|
||||
DiscriminatorP(5),
|
||||
DiscriminatorP(7),
|
||||
DiscriminatorP(11),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for _, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super().__init__()
|
||||
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
|
||||
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
||||
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
||||
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorS(use_spectral_norm=True),
|
||||
DiscriminatorS(),
|
||||
DiscriminatorS(),
|
||||
]
|
||||
)
|
||||
self.meanpools = nn.ModuleList(
|
||||
[AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
|
||||
)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
if i != 0:
|
||||
y = self.meanpools[i - 1](y)
|
||||
y_hat = self.meanpools[i - 1](y_hat)
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
def feature_loss(fmap_r, fmap_g):
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss * 2
|
||||
|
||||
|
||||
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
r_loss = torch.mean((1 - dr) ** 2)
|
||||
g_loss = torch.mean(dg**2)
|
||||
loss += r_loss + g_loss
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
|
||||
return loss, r_losses, g_losses
|
||||
|
||||
|
||||
def generator_loss(disc_outputs):
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
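A minimal usage sketch for the Generator above (illustrative only; it assumes the v1 config dict from matcha/hifigan/config.py and AttributeDict from icefall, both of which are imported by inference.py later in this commit): the generator maps an 80-bin mel spectrogram of shape (batch, 80, frames) to a waveform of shape (batch, 1, frames * prod(h.upsample_rates)).

import torch
from icefall.utils import AttributeDict
from matcha.hifigan.config import v1
from matcha.hifigan.models import Generator

h = AttributeDict(v1)           # HiFiGAN v1 hyperparameters
gen = Generator(h).eval()       # weights would normally be loaded from generator_v1
gen.remove_weight_norm()        # fold weight norm for inference

mel = torch.randn(1, 80, 100)   # (batch, num_mels, frames)
with torch.no_grad():
    wav = gen(mel)              # (batch, 1, 100 * prod(h.upsample_rates))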
|
egs/ljspeech/TTS/matcha/hifigan/xutils.py (new file, 60 lines)
@ -0,0 +1,60 @@
|
||||
""" from https://github.com/jik876/hifi-gan """
|
||||
|
||||
import glob
|
||||
import os
|
||||
|
||||
import matplotlib
|
||||
import torch
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pylab as plt
|
||||
|
||||
|
||||
def plot_spectrogram(spectrogram):
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
||||
plt.colorbar(im, ax=ax)
|
||||
|
||||
fig.canvas.draw()
|
||||
plt.close()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def apply_weight_norm(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
weight_norm(m)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
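# Illustrative examples (not in the original file): this is the "same" padding
# for a stride-1 dilated convolution, e.g. get_padding(5, 1) == 2 and
# get_padding(3, 3) == 3, so the output length matches the input length.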
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
assert os.path.isfile(filepath)
|
||||
print(f"Loading '{filepath}'")
|
||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||
print("Complete.")
|
||||
return checkpoint_dict
|
||||
|
||||
|
||||
def save_checkpoint(filepath, obj):
|
||||
print(f"Saving checkpoint to {filepath}")
|
||||
torch.save(obj, filepath)
|
||||
print("Complete.")
|
||||
|
||||
|
||||
def scan_checkpoint(cp_dir, prefix):
|
||||
pattern = os.path.join(cp_dir, prefix + "????????")
|
||||
cp_list = glob.glob(pattern)
|
||||
if len(cp_list) == 0:
|
||||
return None
|
||||
return sorted(cp_list)[-1]
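# Illustrative note (not in the original file): the "????????" glob matches the
# 8-digit step suffix used by HiFiGAN checkpoint filenames, so e.g.
# scan_checkpoint("exp", "g_") would return the newest file such as
# exp/g_00090000, or None if no matching checkpoint exists.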
|
egs/ljspeech/TTS/matcha/inference.py (new executable file, 199 lines)
@ -0,0 +1,199 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from matcha.hifigan.config import v1, v2, v3
|
||||
from matcha.hifigan.denoiser import Denoiser
|
||||
from matcha.hifigan.models import Generator as HiFiGAN
|
||||
from tokenizer import Tokenizer
|
||||
from train import get_model, get_params
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.utils import AttributeDict
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=4000,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=Path,
|
||||
default="matcha/exp-new-3",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints and logs, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocoder",
|
||||
type=Path,
|
||||
default="./generator_v1",
|
||||
help="Path to the vocoder",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=Path,
|
||||
default="data/tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cmvn",
|
||||
type=str,
|
||||
default="data/fbank/cmvn.json",
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input-text",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The text to generate speech for",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-wav",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The filename of the wave to save the generated speech",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def load_vocoder(checkpoint_path):
|
||||
checkpoint_path = str(checkpoint_path)
|
||||
if checkpoint_path.endswith("v1"):
|
||||
h = AttributeDict(v1)
|
||||
elif checkpoint_path.endswith("v2"):
|
||||
h = AttributeDict(v2)
|
||||
elif checkpoint_path.endswith("v3"):
|
||||
h = AttributeDict(v3)
|
||||
else:
|
||||
raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
|
||||
|
||||
hifigan = HiFiGAN(h).to("cpu")
|
||||
hifigan.load_state_dict(
|
||||
torch.load(checkpoint_path, map_location="cpu")["generator"]
|
||||
)
|
||||
_ = hifigan.eval()
|
||||
hifigan.remove_weight_norm()
|
||||
return hifigan
|
||||
|
||||
|
||||
def to_waveform(mel, vocoder, denoiser):
|
||||
audio = vocoder(mel).clamp(-1, 1)
|
||||
audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
|
||||
return audio.cpu().squeeze()
|
||||
|
||||
|
||||
def process_text(text: str, tokenizer):
|
||||
x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
|
||||
x = torch.tensor(x, dtype=torch.long)
|
||||
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
|
||||
return {"x_orig": text, "x": x, "x_lengths": x_lengths}
|
||||
|
||||
|
||||
def synthesise(
|
||||
model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
|
||||
):
|
||||
text_processed = process_text(text, tokenizer)
|
||||
start_t = dt.datetime.now()
|
||||
output = model.synthesise(
|
||||
text_processed["x"],
|
||||
text_processed["x_lengths"],
|
||||
n_timesteps=n_timesteps,
|
||||
temperature=temperature,
|
||||
spks=spks,
|
||||
length_scale=length_scale,
|
||||
)
|
||||
# merge everything to one dict
|
||||
output.update({"start_t": start_t, **text_processed})
|
||||
return output
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
tokenizer = Tokenizer(params.tokens)
|
||||
params.blank_id = tokenizer.pad_id
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
params.model_args.n_vocab = params.vocab_size
|
||||
|
||||
with open(params.cmvn) as f:
|
||||
stats = json.load(f)
|
||||
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.data_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.model_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file():
|
||||
raise ValueError("{params.exp_dir}/epoch-{params.epoch}.pt does not exist")
|
||||
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
model.eval()
|
||||
|
||||
if not Path(params.vocoder).is_file():
|
||||
raise ValueError(f"{params.vocoder} does not exist")
|
||||
|
||||
vocoder = load_vocoder(params.vocoder)
|
||||
denoiser = Denoiser(vocoder, mode="zeros")
|
||||
|
||||
# Number of ODE Solver steps
|
||||
n_timesteps = 2
|
||||
|
||||
# Changes to the speaking rate
|
||||
length_scale = 1.0
|
||||
|
||||
# Sampling temperature
|
||||
temperature = 0.667
|
||||
|
||||
output = synthesise(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
n_timesteps=n_timesteps,
|
||||
text=params.input_text,
|
||||
length_scale=length_scale,
|
||||
temperature=temperature,
|
||||
)
|
||||
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
|
||||
|
||||
sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
main()
|
egs/ljspeech/TTS/matcha/model.py (new file, 97 lines)
@ -0,0 +1,97 @@
|
||||
# This file is copied from
|
||||
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/model.py
|
||||
""" from https://github.com/jaywalnut310/glow-tts """
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
|
||||
if max_length is None:
|
||||
max_length = length.max()
|
||||
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
||||
return x.unsqueeze(0) < length.unsqueeze(1)
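# Illustrative example (not in the original file):
# sequence_mask(torch.tensor([2, 3])) ->
# tensor([[True, True, False],
#         [True, True, True]])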
|
||||
|
||||
|
||||
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
|
||||
factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
|
||||
length = (length / factor).ceil() * factor
|
||||
if not torch.onnx.is_in_onnx_export():
|
||||
return length.int().item()
|
||||
else:
|
||||
return length
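# Illustrative example (not in the original file): with the default of two
# downsamplings in the U-Net the factor is 4, so fix_len_compatibility(101)
# returns 104, the smallest multiple of 4 that is >= 101.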
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
inverted_shape = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in inverted_shape for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def generate_path(duration, mask):
|
||||
device = duration.device
|
||||
|
||||
b, t_x, t_y = mask.shape
|
||||
cum_duration = torch.cumsum(duration, 1)
|
||||
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
|
||||
|
||||
cum_duration_flat = cum_duration.view(b * t_x)
|
||||
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
||||
path = path.view(b, t_x, t_y)
|
||||
path = (
|
||||
path
|
||||
- torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[
|
||||
:, :-1
|
||||
]
|
||||
)
|
||||
path = path * mask
|
||||
return path
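# Illustrative example (not in the original file): for a single text sequence
# with durations [2, 1, 3] and t_y = 6, the resulting path (before masking) is
#   [[1, 1, 0, 0, 0, 0],
#    [0, 0, 1, 0, 0, 0],
#    [0, 0, 0, 1, 1, 1]]
# i.e. text position i is aligned to the frames between the cumulative
# durations of positions i-1 and i.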
|
||||
|
||||
|
||||
def duration_loss(logw, logw_, lengths):
|
||||
loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
|
||||
return loss
|
||||
|
||||
|
||||
def normalize(data, mu, std):
|
||||
if not isinstance(mu, (float, int)):
|
||||
if isinstance(mu, list):
|
||||
mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
|
||||
elif isinstance(mu, torch.Tensor):
|
||||
mu = mu.to(data.device)
|
||||
elif isinstance(mu, np.ndarray):
|
||||
mu = torch.from_numpy(mu).to(data.device)
|
||||
mu = mu.unsqueeze(-1)
|
||||
|
||||
if not isinstance(std, (float, int)):
|
||||
if isinstance(std, list):
|
||||
std = torch.tensor(std, dtype=data.dtype, device=data.device)
|
||||
elif isinstance(std, torch.Tensor):
|
||||
std = std.to(data.device)
|
||||
elif isinstance(std, np.ndarray):
|
||||
std = torch.from_numpy(std).to(data.device)
|
||||
std = std.unsqueeze(-1)
|
||||
|
||||
return (data - mu) / std
|
||||
|
||||
|
||||
def denormalize(data, mu, std):
|
||||
if not isinstance(mu, float):
|
||||
if isinstance(mu, list):
|
||||
mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
|
||||
elif isinstance(mu, torch.Tensor):
|
||||
mu = mu.to(data.device)
|
||||
elif isinstance(mu, np.ndarray):
|
||||
mu = torch.from_numpy(mu).to(data.device)
|
||||
mu = mu.unsqueeze(-1)
|
||||
|
||||
if not isinstance(std, float):
|
||||
if isinstance(std, list):
|
||||
std = torch.tensor(std, dtype=data.dtype, device=data.device)
|
||||
elif isinstance(std, torch.Tensor):
|
||||
std = std.to(data.device)
|
||||
elif isinstance(std, np.ndarray):
|
||||
std = torch.from_numpy(std).to(data.device)
|
||||
std = std.unsqueeze(-1)
|
||||
|
||||
return data * std + mu
|
egs/ljspeech/TTS/matcha/models/README.md (new file, 3 lines)
@ -0,0 +1,3 @@
|
||||
# Introduction
|
||||
Files in this folder are copied from
|
||||
https://github.com/shivammehta25/Matcha-TTS/tree/main/matcha/models
|
egs/ljspeech/TTS/matcha/models/__init__.py (new empty file)
egs/ljspeech/TTS/matcha/models/components/decoder.py (new file, 459 lines)
@ -0,0 +1,459 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from conformer import ConformerBlock
|
||||
from diffusers.models.activations import get_activation
|
||||
from einops import pack, rearrange, repeat
|
||||
from matcha.models.components.transformer import BasicTransformerBlock
|
||||
|
||||
|
||||
class SinusoidalPosEmb(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
||||
|
||||
def forward(self, x, scale=1000):
|
||||
if x.ndim < 1:
|
||||
x = x.unsqueeze(0)
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
||||
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class Block1D(torch.nn.Module):
|
||||
def __init__(self, dim, dim_out, groups=8):
|
||||
super().__init__()
|
||||
self.block = torch.nn.Sequential(
|
||||
torch.nn.Conv1d(dim, dim_out, 3, padding=1),
|
||||
torch.nn.GroupNorm(groups, dim_out),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x, mask):
|
||||
output = self.block(x * mask)
|
||||
return output * mask
|
||||
|
||||
|
||||
class ResnetBlock1D(torch.nn.Module):
|
||||
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
|
||||
super().__init__()
|
||||
self.mlp = torch.nn.Sequential(
|
||||
nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)
|
||||
)
|
||||
|
||||
self.block1 = Block1D(dim, dim_out, groups=groups)
|
||||
self.block2 = Block1D(dim_out, dim_out, groups=groups)
|
||||
|
||||
self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
|
||||
|
||||
def forward(self, x, mask, time_emb):
|
||||
h = self.block1(x, mask)
|
||||
h += self.mlp(time_emb).unsqueeze(-1)
|
||||
h = self.block2(h, mask)
|
||||
output = h + self.res_conv(x * mask)
|
||||
return output
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
||||
else:
|
||||
self.cond_proj = None
|
||||
|
||||
self.act = get_activation(act_fn)
|
||||
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
else:
|
||||
self.post_act = get_activation(post_act_fn)
|
||||
|
||||
def forward(self, sample, condition=None):
|
||||
if condition is not None:
|
||||
sample = sample + self.cond_proj(condition)
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
|
||||
if self.post_act is not None:
|
||||
sample = self.post_act(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
"""A 1D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
use_conv=False,
|
||||
use_conv_transpose=True,
|
||||
out_channels=None,
|
||||
name="conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
|
||||
self.conv = None
|
||||
if use_conv_transpose:
|
||||
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, inputs):
|
||||
assert inputs.shape[1] == self.channels
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(inputs)
|
||||
|
||||
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if self.use_conv:
|
||||
outputs = self.conv(outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class ConformerWrapper(ConformerBlock):
|
||||
def __init__( # pylint: disable=useless-super-delegation
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
dim_head=64,
|
||||
heads=8,
|
||||
ff_mult=4,
|
||||
conv_expansion_factor=2,
|
||||
conv_kernel_size=31,
|
||||
attn_dropout=0,
|
||||
ff_dropout=0,
|
||||
conv_dropout=0,
|
||||
conv_causal=False,
|
||||
):
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
dim_head=dim_head,
|
||||
heads=heads,
|
||||
ff_mult=ff_mult,
|
||||
conv_expansion_factor=conv_expansion_factor,
|
||||
conv_kernel_size=conv_kernel_size,
|
||||
attn_dropout=attn_dropout,
|
||||
ff_dropout=ff_dropout,
|
||||
conv_dropout=conv_dropout,
|
||||
conv_causal=conv_causal,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
timestep=None,
|
||||
):
|
||||
return super().forward(x=hidden_states, mask=attention_mask.bool())
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
channels=(256, 256),
|
||||
dropout=0.05,
|
||||
attention_head_dim=64,
|
||||
n_blocks=1,
|
||||
num_mid_blocks=2,
|
||||
num_heads=4,
|
||||
act_fn="snake",
|
||||
down_block_type="transformer",
|
||||
mid_block_type="transformer",
|
||||
up_block_type="transformer",
|
||||
):
|
||||
super().__init__()
|
||||
channels = tuple(channels)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
||||
time_embed_dim = channels[0] * 4
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=in_channels,
|
||||
time_embed_dim=time_embed_dim,
|
||||
act_fn="silu",
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_blocks = nn.ModuleList([])
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
output_channel = in_channels
|
||||
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
||||
input_channel = output_channel
|
||||
output_channel = channels[i]
|
||||
is_last = i == len(channels) - 1
|
||||
resnet = ResnetBlock1D(
|
||||
dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
|
||||
)
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
self.get_block(
|
||||
down_block_type,
|
||||
output_channel,
|
||||
attention_head_dim,
|
||||
num_heads,
|
||||
dropout,
|
||||
act_fn,
|
||||
)
|
||||
for _ in range(n_blocks)
|
||||
]
|
||||
)
|
||||
downsample = (
|
||||
Downsample1D(output_channel)
|
||||
if not is_last
|
||||
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
||||
)
|
||||
|
||||
self.down_blocks.append(
|
||||
nn.ModuleList([resnet, transformer_blocks, downsample])
|
||||
)
|
||||
|
||||
for i in range(num_mid_blocks):
|
||||
input_channel = channels[-1]
|
||||
out_channels = channels[-1]
|
||||
|
||||
resnet = ResnetBlock1D(
|
||||
dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
|
||||
)
|
||||
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
self.get_block(
|
||||
mid_block_type,
|
||||
output_channel,
|
||||
attention_head_dim,
|
||||
num_heads,
|
||||
dropout,
|
||||
act_fn,
|
||||
)
|
||||
for _ in range(n_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
||||
|
||||
channels = channels[::-1] + (channels[0],)
|
||||
for i in range(len(channels) - 1):
|
||||
input_channel = channels[i]
|
||||
output_channel = channels[i + 1]
|
||||
is_last = i == len(channels) - 2
|
||||
|
||||
resnet = ResnetBlock1D(
|
||||
dim=2 * input_channel,
|
||||
dim_out=output_channel,
|
||||
time_emb_dim=time_embed_dim,
|
||||
)
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
self.get_block(
|
||||
up_block_type,
|
||||
output_channel,
|
||||
attention_head_dim,
|
||||
num_heads,
|
||||
dropout,
|
||||
act_fn,
|
||||
)
|
||||
for _ in range(n_blocks)
|
||||
]
|
||||
)
|
||||
upsample = (
|
||||
Upsample1D(output_channel, use_conv_transpose=True)
|
||||
if not is_last
|
||||
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
||||
)
|
||||
|
||||
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
||||
|
||||
self.final_block = Block1D(channels[-1], channels[-1])
|
||||
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
||||
|
||||
self.initialize_weights()
|
||||
# nn.init.normal_(self.final_proj.weight)
|
||||
|
||||
@staticmethod
|
||||
def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
|
||||
if block_type == "conformer":
|
||||
block = ConformerWrapper(
|
||||
dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_heads,
|
||||
ff_mult=1,
|
||||
conv_expansion_factor=2,
|
||||
ff_dropout=dropout,
|
||||
attn_dropout=dropout,
|
||||
conv_dropout=dropout,
|
||||
conv_kernel_size=31,
|
||||
)
|
||||
elif block_type == "transformer":
|
||||
block = BasicTransformerBlock(
|
||||
dim=dim,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=act_fn,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown block type {block_type}")
|
||||
|
||||
return block
|
||||
|
||||
def initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
||||
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
elif isinstance(m, nn.GroupNorm):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
||||
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x, mask, mu, t, spks=None, cond=None):
|
||||
"""Forward pass of the UNet1DConditional model.
|
||||
|
||||
        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): encoder output, shape (batch_size, in_channels, time)
            t (torch.Tensor): shape (batch_size,)
            spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (torch.Tensor, optional): placeholder for future use. Defaults to None.

        Returns:
            torch.Tensor: output of shape (batch_size, out_channels, time)
        """
|
||||
|
||||
t = self.time_embeddings(t)
|
||||
t = self.time_mlp(t)
|
||||
|
||||
x = pack([x, mu], "b * t")[0]
|
||||
|
||||
if spks is not None:
|
||||
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
||||
x = pack([x, spks], "b * t")[0]
|
||||
|
||||
hiddens = []
|
||||
masks = [mask]
|
||||
for resnet, transformer_blocks, downsample in self.down_blocks:
|
||||
mask_down = masks[-1]
|
||||
x = resnet(x, mask_down, t)
|
||||
x = rearrange(x, "b c t -> b t c")
|
||||
mask_down = rearrange(mask_down, "b 1 t -> b t")
|
||||
for transformer_block in transformer_blocks:
|
||||
x = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=mask_down,
|
||||
timestep=t,
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t")
|
||||
mask_down = rearrange(mask_down, "b t -> b 1 t")
|
||||
hiddens.append(x) # Save hidden states for skip connections
|
||||
x = downsample(x * mask_down)
|
||||
masks.append(mask_down[:, :, ::2])
|
||||
|
||||
masks = masks[:-1]
|
||||
mask_mid = masks[-1]
|
||||
|
||||
for resnet, transformer_blocks in self.mid_blocks:
|
||||
x = resnet(x, mask_mid, t)
|
||||
x = rearrange(x, "b c t -> b t c")
|
||||
mask_mid = rearrange(mask_mid, "b 1 t -> b t")
|
||||
for transformer_block in transformer_blocks:
|
||||
x = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=mask_mid,
|
||||
timestep=t,
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t")
|
||||
mask_mid = rearrange(mask_mid, "b t -> b 1 t")
|
||||
|
||||
for resnet, transformer_blocks, upsample in self.up_blocks:
|
||||
mask_up = masks.pop()
|
||||
x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
|
||||
x = rearrange(x, "b c t -> b t c")
|
||||
mask_up = rearrange(mask_up, "b 1 t -> b t")
|
||||
for transformer_block in transformer_blocks:
|
||||
x = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=mask_up,
|
||||
timestep=t,
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t")
|
||||
mask_up = rearrange(mask_up, "b t -> b 1 t")
|
||||
x = upsample(x * mask_up)
|
||||
|
||||
x = self.final_block(x, mask_up)
|
||||
output = self.final_proj(x * mask_up)
|
||||
|
||||
return output * mask
|
egs/ljspeech/TTS/matcha/models/components/flow_matching.py (new file, 140 lines)
@ -0,0 +1,140 @@
|
||||
from abc import ABC
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from matcha.models.components.decoder import Decoder
|
||||
|
||||
|
||||
class BASECFM(torch.nn.Module, ABC):
|
||||
def __init__(
|
||||
self,
|
||||
n_feats,
|
||||
cfm_params,
|
||||
n_spks=1,
|
||||
spk_emb_dim=128,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_feats = n_feats
|
||||
self.n_spks = n_spks
|
||||
self.spk_emb_dim = spk_emb_dim
|
||||
self.solver = cfm_params.solver
|
||||
if hasattr(cfm_params, "sigma_min"):
|
||||
self.sigma_min = cfm_params.sigma_min
|
||||
else:
|
||||
self.sigma_min = 1e-4
|
||||
|
||||
self.estimator = None
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
mask (torch.Tensor): output_mask
|
||||
shape: (batch_size, 1, mel_timesteps)
|
||||
n_timesteps (int): number of diffusion steps
|
||||
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
||||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
cond: Not used but kept for future purposes
|
||||
|
||||
Returns:
|
||||
sample: generated mel-spectrogram
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
z = torch.randn_like(mu) * temperature
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
||||
return self.solve_euler(
|
||||
z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
|
||||
)
|
||||
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond):
|
||||
"""
|
||||
Fixed-step Euler solver for ODEs.
|
||||
Args:
|
||||
x (torch.Tensor): random noise
|
||||
t_span (torch.Tensor): n_timesteps interpolated
|
||||
shape: (n_timesteps + 1,)
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
mask (torch.Tensor): output_mask
|
||||
shape: (batch_size, 1, mel_timesteps)
|
||||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
cond: Not used but kept for future purposes
|
||||
"""
|
||||
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
||||
|
||||
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
||||
# Or in future might add like a return_all_steps flag
|
||||
sol = []
|
||||
|
||||
for step in range(1, len(t_span)):
|
||||
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
|
||||
|
||||
x = x + dt * dphi_dt
|
||||
t = t + dt
|
||||
sol.append(x)
|
||||
if step < len(t_span) - 1:
|
||||
dt = t_span[step + 1] - t
|
||||
|
||||
return sol[-1]
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
||||
"""Computes diffusion loss
|
||||
|
||||
Args:
|
||||
x1 (torch.Tensor): Target
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
mask (torch.Tensor): target mask
|
||||
shape: (batch_size, 1, mel_timesteps)
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
|
||||
Returns:
|
||||
loss: conditional flow matching loss
|
||||
y: conditional flow
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
b, _, t = mu.shape
|
||||
|
||||
# random timestep
|
||||
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
||||
# sample noise p(x_0)
|
||||
z = torch.randn_like(x1)
|
||||
|
||||
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
||||
u = x1 - (1 - self.sigma_min) * z
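        # Clarifying note (not in the original file): with s = sigma_min, the two
        # lines above implement the optimal-transport conditional flow matching
        # target,
        #   y_t = (1 - (1 - s) * t) * z + t * x1,
        #   u_t = d y_t / d t = x1 - (1 - s) * z,
        # so the estimator below is trained to predict the constant velocity of
        # the straight path from the noise sample z to the target mel x1.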
|
||||
|
||||
loss = F.mse_loss(
|
||||
self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum"
|
||||
) / (torch.sum(mask) * u.shape[1])
|
||||
return loss, y
|
||||
|
||||
|
||||
class CFM(BASECFM):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channel,
|
||||
cfm_params,
|
||||
decoder_params,
|
||||
n_spks=1,
|
||||
spk_emb_dim=64,
|
||||
):
|
||||
super().__init__(
|
||||
n_feats=in_channels,
|
||||
cfm_params=cfm_params,
|
||||
n_spks=n_spks,
|
||||
spk_emb_dim=spk_emb_dim,
|
||||
)
|
||||
|
||||
in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = Decoder(
|
||||
in_channels=in_channels, out_channels=out_channel, **decoder_params
|
||||
)
|
egs/ljspeech/TTS/matcha/models/components/text_encoder.py (new file, 447 lines)
@ -0,0 +1,447 @@
|
||||
""" from https://github.com/jaywalnut310/glow-tts """
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from matcha.model import sequence_mask
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-4):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
|
||||
self.gamma = torch.nn.Parameter(torch.ones(channels))
|
||||
self.beta = torch.nn.Parameter(torch.zeros(channels))
|
||||
|
||||
def forward(self, x):
|
||||
n_dims = len(x.shape)
|
||||
mean = torch.mean(x, 1, keepdim=True)
|
||||
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
|
||||
|
||||
x = (x - mean) * torch.rsqrt(variance + self.eps)
|
||||
|
||||
shape = [1, -1] + [1] * (n_dims - 2)
|
||||
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
||||
return x
|
||||
|
||||
|
||||
class ConvReluNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
n_layers,
|
||||
p_dropout,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.n_layers = n_layers
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.conv_layers = torch.nn.ModuleList()
|
||||
self.norm_layers = torch.nn.ModuleList()
|
||||
self.conv_layers.append(
|
||||
torch.nn.Conv1d(
|
||||
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
)
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
self.relu_drop = torch.nn.Sequential(
|
||||
torch.nn.ReLU(), torch.nn.Dropout(p_dropout)
|
||||
)
|
||||
for _ in range(n_layers - 1):
|
||||
self.conv_layers.append(
|
||||
torch.nn.Conv1d(
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
)
|
||||
)
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x_org = x
|
||||
for i in range(self.n_layers):
|
||||
x = self.conv_layers[i](x * x_mask)
|
||||
x = self.norm_layers[i](x)
|
||||
x = self.relu_drop(x)
|
||||
x = x_org + self.proj(x)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.drop = torch.nn.Dropout(p_dropout)
|
||||
self.conv_1 = torch.nn.Conv1d(
|
||||
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
self.norm_1 = LayerNorm(filter_channels)
|
||||
self.conv_2 = torch.nn.Conv1d(
|
||||
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
self.norm_2 = LayerNorm(filter_channels)
|
||||
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x = self.conv_1(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.norm_1(x)
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.norm_2(x)
|
||||
x = self.drop(x)
|
||||
x = self.proj(x * x_mask)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class RotaryPositionalEmbeddings(nn.Module):
|
||||
"""
|
||||
## RoPE module
|
||||
|
||||
Rotary encoding transforms pairs of features by rotating in the 2D plane.
|
||||
That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
|
||||
Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
|
||||
by an angle depending on the position of the token.
|
||||
"""
|
||||
|
||||
def __init__(self, d: int, base: int = 10_000):
|
||||
r"""
|
||||
* `d` is the number of features $d$
|
||||
* `base` is the constant used for calculating $\Theta$
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.base = base
|
||||
self.d = int(d)
|
||||
self.cos_cached = None
|
||||
self.sin_cached = None
|
||||
|
||||
def _build_cache(self, x: torch.Tensor):
|
||||
r"""
|
||||
Cache $\cos$ and $\sin$ values
|
||||
"""
|
||||
# Return if cache is already built
|
||||
if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
|
||||
return
|
||||
|
||||
# Get sequence length
|
||||
seq_len = x.shape[0]
|
||||
|
||||
# $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
|
||||
theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(
|
||||
x.device
|
||||
)
|
||||
|
||||
# Create position indexes `[0, 1, ..., seq_len - 1]`
|
||||
seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
|
||||
|
||||
# Calculate the product of position index and $\theta_i$
|
||||
idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
|
||||
|
||||
# Concatenate so that for row $m$ we have
|
||||
# $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
|
||||
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
|
||||
|
||||
# Cache them
|
||||
self.cos_cached = idx_theta2.cos()[:, None, None, :]
|
||||
self.sin_cached = idx_theta2.sin()[:, None, None, :]
|
||||
|
||||
def _neg_half(self, x: torch.Tensor):
|
||||
# $\frac{d}{2}$
|
||||
d_2 = self.d // 2
|
||||
|
||||
# Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
|
||||
return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
* `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
|
||||
"""
|
||||
# Cache $\cos$ and $\sin$ values
|
||||
x = rearrange(x, "b h t d -> t b h d")
|
||||
|
||||
self._build_cache(x)
|
||||
|
||||
# Split the features, we can choose to apply rotary embeddings only to a partial set of features.
|
||||
x_rope, x_pass = x[..., : self.d], x[..., self.d :]
|
||||
|
||||
# Calculate
|
||||
# $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
|
||||
neg_half_x = self._neg_half(x_rope)
|
||||
|
||||
x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (
|
||||
neg_half_x * self.sin_cached[: x.shape[0]]
|
||||
)
|
||||
|
||||
return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
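    # Clarifying note (not in the original file): for the feature pair
    # (x_i, x_{i + d/2}) at position m, the lines above compute
    #   x_i'       = x_i * cos(m * theta_i) - x_{i + d/2} * sin(m * theta_i)
    #   x_{i+d/2}' = x_{i + d/2} * cos(m * theta_i) + x_i * sin(m * theta_i)
    # i.e. a 2-D rotation by an angle that grows linearly with the position,
    # which makes the attention scores depend only on relative offsets.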
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
out_channels,
|
||||
n_heads,
|
||||
heads_share=True,
|
||||
p_dropout=0.0,
|
||||
proximal_bias=False,
|
||||
proximal_init=False,
|
||||
):
|
||||
super().__init__()
|
||||
assert channels % n_heads == 0
|
||||
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels
|
||||
self.n_heads = n_heads
|
||||
self.heads_share = heads_share
|
||||
self.proximal_bias = proximal_bias
|
||||
self.p_dropout = p_dropout
|
||||
self.attn = None
|
||||
|
||||
self.k_channels = channels // n_heads
|
||||
self.conv_q = torch.nn.Conv1d(channels, channels, 1)
|
||||
self.conv_k = torch.nn.Conv1d(channels, channels, 1)
|
||||
self.conv_v = torch.nn.Conv1d(channels, channels, 1)
|
||||
|
||||
# from https://nn.labml.ai/transformers/rope/index.html
|
||||
self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
|
||||
self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
|
||||
|
||||
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
|
||||
self.drop = torch.nn.Dropout(p_dropout)
|
||||
|
||||
torch.nn.init.xavier_uniform_(self.conv_q.weight)
|
||||
torch.nn.init.xavier_uniform_(self.conv_k.weight)
|
||||
if proximal_init:
|
||||
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
|
||||
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
|
||||
torch.nn.init.xavier_uniform_(self.conv_v.weight)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
q = self.conv_q(x)
|
||||
k = self.conv_k(c)
|
||||
v = self.conv_v(c)
|
||||
|
||||
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
||||
|
||||
x = self.conv_o(x)
|
||||
return x
|
||||
|
||||
def attention(self, query, key, value, mask=None):
|
||||
b, d, t_s, t_t = (*key.size(), query.size(2))
|
||||
query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
|
||||
key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
|
||||
value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)
|
||||
|
||||
query = self.query_rotary_pe(query)
|
||||
key = self.key_rotary_pe(key)
|
||||
|
||||
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
|
||||
|
||||
if self.proximal_bias:
|
||||
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
||||
scores = scores + self._attention_bias_proximal(t_s).to(
|
||||
device=scores.device, dtype=scores.dtype
|
||||
)
|
||||
if mask is not None:
|
||||
scores = scores.masked_fill(mask == 0, -1e4)
|
||||
p_attn = torch.nn.functional.softmax(scores, dim=-1)
|
||||
p_attn = self.drop(p_attn)
|
||||
output = torch.matmul(p_attn, value)
|
||||
output = output.transpose(2, 3).contiguous().view(b, d, t_t)
|
||||
return output, p_attn
|
||||
|
||||
@staticmethod
|
||||
def _attention_bias_proximal(length):
|
||||
r = torch.arange(length, dtype=torch.float32)
|
||||
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
||||
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
||||
|
||||
|
||||
class FFN(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.conv_1 = torch.nn.Conv1d(
|
||||
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
self.conv_2 = torch.nn.Conv1d(
|
||||
filter_channels, out_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
self.drop = torch.nn.Dropout(p_dropout)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x = self.conv_1(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(x * x_mask)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size=1,
|
||||
p_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.drop = torch.nn.Dropout(p_dropout)
|
||||
self.attn_layers = torch.nn.ModuleList()
|
||||
self.norm_layers_1 = torch.nn.ModuleList()
|
||||
self.ffn_layers = torch.nn.ModuleList()
|
||||
self.norm_layers_2 = torch.nn.ModuleList()
|
||||
for _ in range(self.n_layers):
|
||||
self.attn_layers.append(
|
||||
MultiHeadAttention(
|
||||
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
|
||||
)
|
||||
)
|
||||
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
||||
self.ffn_layers.append(
|
||||
FFN(
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
kernel_size,
|
||||
p_dropout=p_dropout,
|
||||
)
|
||||
)
|
||||
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
for i in range(self.n_layers):
|
||||
x = x * x_mask
|
||||
y = self.attn_layers[i](x, x, attn_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_1[i](x + y)
|
||||
y = self.ffn_layers[i](x, x_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_2[i](x + y)
|
||||
x = x * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_type,
|
||||
encoder_params,
|
||||
duration_predictor_params,
|
||||
n_vocab,
|
||||
n_spks=1,
|
||||
spk_emb_dim=128,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder_type = encoder_type
|
||||
self.n_vocab = n_vocab
|
||||
self.n_feats = encoder_params.n_feats
|
||||
self.n_channels = encoder_params.n_channels
|
||||
self.spk_emb_dim = spk_emb_dim
|
||||
self.n_spks = n_spks
|
||||
|
||||
self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
|
||||
torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
|
||||
|
||||
if encoder_params.prenet:
|
||||
self.prenet = ConvReluNorm(
|
||||
self.n_channels,
|
||||
self.n_channels,
|
||||
self.n_channels,
|
||||
kernel_size=5,
|
||||
n_layers=3,
|
||||
p_dropout=0.5,
|
||||
)
|
||||
else:
|
||||
self.prenet = lambda x, x_mask: x
|
||||
|
||||
self.encoder = Encoder(
|
||||
encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
|
||||
encoder_params.filter_channels,
|
||||
encoder_params.n_heads,
|
||||
encoder_params.n_layers,
|
||||
encoder_params.kernel_size,
|
||||
encoder_params.p_dropout,
|
||||
)
|
||||
|
||||
self.proj_m = torch.nn.Conv1d(
|
||||
self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1
|
||||
)
|
||||
self.proj_w = DurationPredictor(
|
||||
self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
|
||||
duration_predictor_params.filter_channels_dp,
|
||||
duration_predictor_params.kernel_size,
|
||||
duration_predictor_params.p_dropout,
|
||||
)
|
||||
|
||||
def forward(self, x, x_lengths, spks=None):
|
||||
"""Run forward pass to the transformer based encoder and duration predictor
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): text input
|
||||
shape: (batch_size, max_text_length)
|
||||
x_lengths (torch.Tensor): text input lengths
|
||||
shape: (batch_size,)
|
||||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||
shape: (batch_size,)
|
||||
|
||||
Returns:
|
||||
mu (torch.Tensor): average output of the encoder
|
||||
shape: (batch_size, n_feats, max_text_length)
|
||||
logw (torch.Tensor): log duration predicted by the duration predictor
|
||||
shape: (batch_size, 1, max_text_length)
|
||||
x_mask (torch.Tensor): mask for the text input
|
||||
shape: (batch_size, 1, max_text_length)
|
||||
"""
|
||||
x = self.emb(x) * math.sqrt(self.n_channels)
|
||||
x = torch.transpose(x, 1, -1)
|
||||
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
|
||||
x = self.prenet(x, x_mask)
|
||||
if self.n_spks > 1:
|
||||
x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
|
||||
x = self.encoder(x, x_mask)
|
||||
mu = self.proj_m(x) * x_mask
|
||||
|
||||
x_dp = torch.detach(x)
|
||||
logw = self.proj_w(x_dp, x_mask)
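        # Clarifying note (not in the original file): the duration predictor
        # runs on a detached copy of the encoder output (x_dp above), so the
        # duration loss does not backpropagate into the text encoder.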
|
||||
|
||||
return mu, logw, x_mask
|
egs/ljspeech/TTS/matcha/models/components/transformer.py (new file, 353 lines)
@ -0,0 +1,353 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.models.attention import (
|
||||
GEGLU,
|
||||
GELU,
|
||||
AdaLayerNorm,
|
||||
AdaLayerNormZero,
|
||||
ApproximateGELU,
|
||||
)
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from diffusers.models.lora import LoRACompatibleLinear
|
||||
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
"""
|
||||
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
References:
|
||||
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = snakebeta(256)
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
alpha=1.0,
|
||||
alpha_trainable=True,
|
||||
alpha_logscale=True,
|
||||
):
|
||||
"""
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: shape of the input
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||
alpha will be trained along with the rest of your model.
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_features = (
|
||||
out_features if isinstance(out_features, list) else [out_features]
|
||||
)
|
||||
self.proj = LoRACompatibleLinear(in_features, out_features)
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||
self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
||||
self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
||||
else: # linear scale alphas initialized to ones
|
||||
self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
|
||||
self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||
"""
|
||||
x = self.proj(x)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(self.alpha)
|
||||
beta = torch.exp(self.beta)
|
||||
else:
|
||||
alpha = self.alpha
|
||||
beta = self.beta
|
||||
|
||||
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
|
||||
torch.sin(x * alpha), 2
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input.
|
||||
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
||||
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "geglu",
|
||||
final_dropout: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
if activation_fn == "gelu":
|
||||
act_fn = GELU(dim, inner_dim)
|
||||
if activation_fn == "gelu-approximate":
|
||||
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
||||
elif activation_fn == "geglu":
|
||||
act_fn = GEGLU(dim, inner_dim)
|
||||
elif activation_fn == "geglu-approximate":
|
||||
act_fn = ApproximateGELU(dim, inner_dim)
|
||||
elif activation_fn == "snakebeta":
|
||||
act_fn = SnakeBeta(dim, inner_dim)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
# project in
|
||||
self.net.append(act_fn)
|
||||
# project dropout
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
# project out
|
||||
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
|
||||
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
||||
if final_dropout:
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm",
|
||||
final_dropout: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
self.use_ada_layer_norm_zero = (
|
||||
num_embeds_ada_norm is not None
|
||||
) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (
|
||||
num_embeds_ada_norm is not None
|
||||
) and norm_type == "ada_norm"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
||||
)
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
if self.use_ada_layer_norm:
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif self.use_ada_layer_norm_zero:
|
||||
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None or double_self_attention:
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
self.norm2 = (
|
||||
AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
if self.use_ada_layer_norm
|
||||
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
||||
)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim
|
||||
if not double_self_attention
|
||||
else None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
# scale_qk=False, # uncomment this to not to use flash attention
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 1. Self-Attention
|
||||
if self.use_ada_layer_norm:
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
cross_attention_kwargs = (
|
||||
cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states
|
||||
if self.only_cross_attention
|
||||
else None,
|
||||
attention_mask=encoder_attention_mask
|
||||
if self.only_cross_attention
|
||||
else attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
if self.use_ada_layer_norm_zero:
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 2. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep)
|
||||
if self.use_ada_layer_norm
|
||||
else self.norm2(hidden_states)
|
||||
)
|
||||
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states = (
|
||||
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
)
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
||||
raise ValueError(
|
||||
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
||||
)
|
||||
|
||||
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
||||
ff_output = torch.cat(
|
||||
[
|
||||
self.ff(hid_slice)
|
||||
for hid_slice in norm_hidden_states.chunk(
|
||||
num_chunks, dim=self._chunk_dim
|
||||
)
|
||||
],
|
||||
dim=self._chunk_dim,
|
||||
)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
return hidden_states
|
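As a rough usage sketch of the block defined above (the dimensions are invented for illustration, not taken from the recipe), a `BasicTransformerBlock` with the `snakebeta` feed-forward activation used by the Matcha decoder can be exercised like this:

import torch

block = BasicTransformerBlock(
    dim=256,                 # hypothetical channel count
    num_attention_heads=2,
    attention_head_dim=64,
    dropout=0.05,
    activation_fn="snakebeta",
)
hidden = torch.randn(1, 100, 256)  # (batch, time, channels)
out = block(hidden)                # plain self-attention + SnakeBeta feed-forward
print(out.shape)                   # torch.Size([1, 100, 256])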
295
egs/ljspeech/TTS/matcha/models/matcha_tts.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import datetime as dt
|
||||
import math
|
||||
import random
|
||||
|
||||
import matcha.monotonic_align as monotonic_align
|
||||
import torch
|
||||
from matcha.model import (
|
||||
denormalize,
|
||||
duration_loss,
|
||||
fix_len_compatibility,
|
||||
generate_path,
|
||||
sequence_mask,
|
||||
)
|
||||
from matcha.models.components.flow_matching import CFM
|
||||
from matcha.models.components.text_encoder import TextEncoder
|
||||
|
||||
|
||||
class MatchaTTS(torch.nn.Module): # 🍵
|
||||
def __init__(
|
||||
self,
|
||||
n_vocab,
|
||||
n_spks,
|
||||
spk_emb_dim,
|
||||
n_feats,
|
||||
encoder,
|
||||
decoder,
|
||||
cfm,
|
||||
data_statistics,
|
||||
out_size,
|
||||
optimizer=None,
|
||||
scheduler=None,
|
||||
prior_loss=True,
|
||||
use_precomputed_durations=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# self.save_hyperparameters(logger=False)
|
||||
|
||||
self.n_vocab = n_vocab
|
||||
self.n_spks = n_spks
|
||||
self.spk_emb_dim = spk_emb_dim
|
||||
self.n_feats = n_feats
|
||||
self.out_size = out_size
|
||||
self.prior_loss = prior_loss
|
||||
self.use_precomputed_durations = use_precomputed_durations
|
||||
|
||||
if n_spks > 1:
|
||||
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
|
||||
|
||||
self.encoder = TextEncoder(
|
||||
encoder.encoder_type,
|
||||
encoder.encoder_params,
|
||||
encoder.duration_predictor_params,
|
||||
n_vocab,
|
||||
n_spks,
|
||||
spk_emb_dim,
|
||||
)
|
||||
|
||||
self.decoder = CFM(
|
||||
in_channels=2 * encoder.encoder_params.n_feats,
|
||||
out_channel=encoder.encoder_params.n_feats,
|
||||
cfm_params=cfm,
|
||||
decoder_params=decoder,
|
||||
n_spks=n_spks,
|
||||
spk_emb_dim=spk_emb_dim,
|
||||
)
|
||||
|
||||
if data_statistics is not None:
|
||||
self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
|
||||
self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
|
||||
else:
|
||||
self.register_buffer("mel_mean", torch.tensor(0.0))
|
||||
self.register_buffer("mel_std", torch.tensor(1.0))
|
||||
|
||||
@torch.inference_mode()
|
||||
def synthesise(
|
||||
self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0
|
||||
):
|
||||
"""
|
||||
Generates mel-spectrogram from text. Returns:
|
||||
1. encoder outputs
|
||||
2. decoder outputs
|
||||
3. generated alignment
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
|
||||
shape: (batch_size, max_text_length)
|
||||
x_lengths (torch.Tensor): lengths of texts in batch.
|
||||
shape: (batch_size,)
|
||||
n_timesteps (int): number of steps to use for reverse diffusion in decoder.
|
||||
temperature (float, optional): controls variance of terminal distribution.
|
||||
spks (bool, optional): speaker ids.
|
||||
shape: (batch_size,)
|
||||
length_scale (float, optional): controls speech pace.
|
||||
Increase value to slow down generated speech and vice versa.
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
|
||||
# Average mel spectrogram generated by the encoder
|
||||
"decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
|
||||
# Refined mel spectrogram improved by the CFM
|
||||
"attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
|
||||
# Alignment map between text and mel spectrogram
|
||||
"mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
|
||||
# Denormalized mel spectrogram
|
||||
"mel_lengths": torch.Tensor, shape: (batch_size,),
|
||||
# Lengths of mel spectrograms
|
||||
"rtf": float,
|
||||
# Real-time factor
|
||||
"""
|
||||
# For RTF computation
|
||||
t = dt.datetime.now()
|
||||
|
||||
if self.n_spks > 1:
|
||||
# Get speaker embedding
|
||||
spks = self.spk_emb(spks.long())
|
||||
|
||||
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
|
||||
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
|
||||
|
||||
w = torch.exp(logw) * x_mask
|
||||
w_ceil = torch.ceil(w) * length_scale
|
||||
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
||||
y_max_length = y_lengths.max()
|
||||
y_max_length_ = fix_len_compatibility(y_max_length)
|
||||
|
||||
# Using obtained durations `w` construct alignment map `attn`
|
||||
y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
|
||||
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
|
||||
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
|
||||
|
||||
# Align encoded text and get mu_y
|
||||
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
|
||||
mu_y = mu_y.transpose(1, 2)
|
||||
encoder_outputs = mu_y[:, :, :y_max_length]
|
||||
|
||||
# Generate sample tracing the probability flow
|
||||
decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
|
||||
decoder_outputs = decoder_outputs[:, :, :y_max_length]
|
||||
|
||||
t = (dt.datetime.now() - t).total_seconds()
|
||||
rtf = t * 22050 / (decoder_outputs.shape[-1] * 256)
|
||||
|
||||
return {
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"decoder_outputs": decoder_outputs,
|
||||
"attn": attn[:, :, :y_max_length],
|
||||
"mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std),
|
||||
"mel_lengths": y_lengths,
|
||||
"rtf": rtf,
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
x_lengths,
|
||||
y,
|
||||
y_lengths,
|
||||
spks=None,
|
||||
out_size=None,
|
||||
cond=None,
|
||||
durations=None,
|
||||
):
|
||||
"""
|
||||
Computes 3 losses:
|
||||
1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
|
||||
2. prior loss: loss between mel-spectrogram and encoder outputs.
|
||||
3. flow matching loss: loss between mel-spectrogram and decoder outputs.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
|
||||
shape: (batch_size, max_text_length)
|
||||
x_lengths (torch.Tensor): lengths of texts in batch.
|
||||
shape: (batch_size,)
|
||||
y (torch.Tensor): batch of corresponding mel-spectrograms.
|
||||
shape: (batch_size, n_feats, max_mel_length)
|
||||
y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
|
||||
shape: (batch_size,)
|
||||
out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained.
|
||||
Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size.
|
||||
spks (torch.Tensor, optional): speaker ids.
|
||||
shape: (batch_size,)
|
||||
"""
|
||||
if self.n_spks > 1:
|
||||
# Get speaker embedding
|
||||
spks = self.spk_emb(spks)
|
||||
|
||||
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
|
||||
mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
|
||||
y_max_length = y.shape[-1]
|
||||
|
||||
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
|
||||
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
|
||||
|
||||
if self.use_precomputed_durations:
|
||||
attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
|
||||
else:
|
||||
# Use MAS to find most likely alignment `attn` between text and mel-spectrogram
|
||||
with torch.no_grad():
|
||||
const = -0.5 * math.log(2 * math.pi) * self.n_feats
|
||||
factor = -0.5 * torch.ones(
|
||||
mu_x.shape, dtype=mu_x.dtype, device=mu_x.device
|
||||
)
|
||||
y_square = torch.matmul(factor.transpose(1, 2), y**2)
|
||||
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
|
||||
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
|
||||
log_prior = y_square - y_mu_double + mu_square + const
|
||||
|
||||
attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
|
||||
attn = attn.detach() # b, t_text, T_mel
|
||||
|
||||
# Compute loss between predicted log-scaled durations and those obtained from MAS
|
||||
# referred to as prior loss in the paper
|
||||
logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
|
||||
dur_loss = duration_loss(logw, logw_, x_lengths)
|
||||
|
||||
# Cut a small segment of mel-spectrogram in order to increase batch size
|
||||
# - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it
|
||||
# - Do not need this hack for Matcha-TTS, but it works with it as well
|
||||
if not isinstance(out_size, type(None)):
|
||||
max_offset = (y_lengths - out_size).clamp(0)
|
||||
offset_ranges = list(
|
||||
zip([0] * max_offset.shape[0], max_offset.cpu().numpy())
|
||||
)
|
||||
out_offset = torch.LongTensor(
|
||||
[
|
||||
torch.tensor(random.choice(range(start, end)) if end > start else 0)
|
||||
for start, end in offset_ranges
|
||||
]
|
||||
).to(y_lengths)
|
||||
attn_cut = torch.zeros(
|
||||
attn.shape[0],
|
||||
attn.shape[1],
|
||||
out_size,
|
||||
dtype=attn.dtype,
|
||||
device=attn.device,
|
||||
)
|
||||
y_cut = torch.zeros(
|
||||
y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device
|
||||
)
|
||||
|
||||
y_cut_lengths = []
|
||||
for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
|
||||
y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0)
|
||||
y_cut_lengths.append(y_cut_length)
|
||||
cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
|
||||
y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
|
||||
attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]
|
||||
|
||||
y_cut_lengths = torch.LongTensor(y_cut_lengths)
|
||||
y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask)
|
||||
|
||||
attn = attn_cut
|
||||
y = y_cut
|
||||
y_mask = y_cut_mask
|
||||
|
||||
# Align encoded text with mel-spectrogram and get mu_y segment
|
||||
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
|
||||
mu_y = mu_y.transpose(1, 2)
|
||||
|
||||
# Compute loss of the decoder
|
||||
diff_loss, _ = self.decoder.compute_loss(
|
||||
x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond
|
||||
)
|
||||
|
||||
if self.prior_loss:
|
||||
prior_loss = torch.sum(
|
||||
0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask
|
||||
)
|
||||
prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
|
||||
else:
|
||||
prior_loss = 0
|
||||
|
||||
return dur_loss, prior_loss, diff_loss, attn
|
||||
|
||||
def get_losses(self, batch):
|
||||
x, x_lengths = batch["x"], batch["x_lengths"]
|
||||
y, y_lengths = batch["y"], batch["y_lengths"]
|
||||
spks = batch["spks"]
|
||||
|
||||
dur_loss, prior_loss, diff_loss, *_ = self(
|
||||
x=x,
|
||||
x_lengths=x_lengths,
|
||||
y=y,
|
||||
y_lengths=y_lengths,
|
||||
spks=spks,
|
||||
out_size=self.out_size,
|
||||
durations=batch["durations"],
|
||||
)
|
||||
return {
|
||||
"dur_loss": dur_loss,
|
||||
"prior_loss": prior_loss,
|
||||
"diff_loss": diff_loss,
|
||||
}
|
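The no-grad block in `MatchaTTS.forward` computes the Gaussian log-likelihood log N(y_s; mu_t, I) for every (text token, mel frame) pair in a vectorised form before handing it to monotonic alignment search. A small self-contained check (sizes invented) that the vectorised expression matches the direct definition:

import math
import torch

b, n_feats, t_text, t_mel = 1, 4, 3, 5   # hypothetical sizes
mu_x = torch.randn(b, n_feats, t_text)
y = torch.randn(b, n_feats, t_mel)

# vectorised form, as in MatchaTTS.forward
const = -0.5 * math.log(2 * math.pi) * n_feats
factor = -0.5 * torch.ones_like(mu_x)
y_square = torch.matmul(factor.transpose(1, 2), y**2)
y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
log_prior = y_square - y_mu_double + mu_square + const

# direct per-pair computation: -0.5 * ||y_s - mu_t||^2 - 0.5 * n_feats * log(2 * pi)
direct = torch.empty(b, t_text, t_mel)
for t in range(t_text):
    for s in range(t_mel):
        diff = y[0, :, s] - mu_x[0, :, t]
        direct[0, t, s] = -0.5 * (diff**2).sum() - 0.5 * n_feats * math.log(2 * math.pi)

print(torch.allclose(log_prior, direct, atol=1e-5))  # True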
3
egs/ljspeech/TTS/matcha/monotonic_align/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
build
core.c
*.so
23
egs/ljspeech/TTS/matcha/monotonic_align/__init__.py
Normal file
@@ -0,0 +1,23 @@
# Copied from
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/__init__.py
import numpy as np
import torch

from matcha.monotonic_align.core import maximum_path_c


def maximum_path(value, mask):
    """Cython optimised version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    value = value * mask
    device = value.device
    dtype = value.dtype
    value = value.data.cpu().numpy().astype(np.float32)
    path = np.zeros_like(value).astype(np.int32)
    mask = mask.data.cpu().numpy()

    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
    maximum_path_c(path, value, t_x_max, t_y_max)
    return torch.from_numpy(path).to(device=device, dtype=dtype)
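Once the Cython extension is built, `maximum_path` can be exercised on toy tensors like this (the scores below are invented); it returns a hard, monotonic 0/1 alignment between text tokens and mel frames:

import torch

# hypothetical scores for 1 utterance: 2 text tokens vs 4 mel frames
value = torch.tensor([[[0.9, 0.8, 0.1, 0.1],
                       [0.1, 0.2, 0.9, 0.9]]])  # (b, t_x, t_y)
mask = torch.ones_like(value)                    # no padding
path = maximum_path(value, mask)
print(path[0])
# tensor([[1., 1., 0., 0.],
#         [0., 0., 1., 1.]])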
49
egs/ljspeech/TTS/matcha/monotonic_align/core.pyx
Normal file
@@ -0,0 +1,49 @@
# Copied from
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/core.pyx
import numpy as np

cimport cython
cimport numpy as np

from cython.parallel import prange


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
    cdef int x
    cdef int y
    cdef float v_prev
    cdef float v_cur
    cdef float tmp
    cdef int index = t_x - 1

    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            if x == y:
                v_cur = max_neg_val
            else:
                v_cur = value[x, y-1]
            if x == 0:
                if y == 0:
                    v_prev = 0.
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[x-1, y-1]
            value[x, y] = max(v_cur, v_prev) + value[x, y]

    for y in range(t_y - 1, -1, -1):
        path[index, y] = 1
        if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
            index = index - 1


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
    cdef int b = values.shape[0]

    cdef int i
    for i in prange(b, nogil=True):
        maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
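The nogil loop above is a Viterbi-style dynamic program over a (t_x, t_y) grid followed by backtracking. For readability only, here is a plain-Python sketch of the same per-utterance recursion (the recipe itself uses the compiled version):

import numpy as np

def maximum_path_each_py(value: np.ndarray, t_x: int, t_y: int, max_neg_val: float = -1e9) -> np.ndarray:
    """Reference implementation of the monotonic-alignment DP for one utterance."""
    value = value.copy()
    path = np.zeros((t_x, t_y), dtype=np.int32)
    # forward pass: best cumulative score of reaching cell (x, y) monotonically
    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            v_cur = max_neg_val if x == y else value[x, y - 1]
            if x == 0:
                v_prev = 0.0 if y == 0 else max_neg_val
            else:
                v_prev = value[x - 1, y - 1]
            value[x, y] = max(v_cur, v_prev) + value[x, y]
    # backtracking: walk from the last token back over the frames
    index = t_x - 1
    for y in range(t_y - 1, -1, -1):
        path[index, y] = 1
        if index != 0 and (index == y or value[index, y - 1] < value[index - 1, y - 1]):
            index -= 1
    return path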
12
egs/ljspeech/TTS/matcha/monotonic_align/setup.py
Normal file
@@ -0,0 +1,12 @@
# Copied from
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/setup.py
from distutils.core import setup

import numpy
from Cython.Build import cythonize

setup(
    name="monotonic_align",
    ext_modules=cythonize("core.pyx"),
    include_dirs=[numpy.get_include()],
)
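The extension has to be compiled before `matcha.monotonic_align` can be imported; with a distutils/Cython setup like the one above this is typically done in place, e.g. `cd egs/ljspeech/TTS/matcha/monotonic_align && python3 setup.py build_ext --inplace`, which produces the `build/`, `core.c` and `*.so` artifacts listed in the .gitignore.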
204
egs/ljspeech/TTS/matcha/onnx_pretrained.py
Executable file
@@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import logging
|
||||
|
||||
import onnxruntime as ort
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from inference import load_vocoder
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--acoustic-model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the acoustic model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vocoder",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the vocoder",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input-text",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The text to generate speech for",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-wav",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The filename of the wave to save the generated speech",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class OnnxHifiGANModel:
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 1
|
||||
|
||||
self.session_opts = session_opts
|
||||
self.model = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
for i in self.model.get_inputs():
|
||||
print(i)
|
||||
|
||||
print("-----")
|
||||
|
||||
for i in self.model.get_outputs():
|
||||
print(i)
|
||||
|
||||
def __call__(self, x: torch.tensor):
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x.shape[0] == 1, x.shape
|
||||
|
||||
audio = self.model.run(
|
||||
[self.model.get_outputs()[0].name],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(audio)
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 2
|
||||
|
||||
self.session_opts = session_opts
|
||||
self.tokenizer = Tokenizer("./data/tokens.txt")
|
||||
self.model = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
for i in self.model.get_inputs():
|
||||
print(i)
|
||||
|
||||
print("-----")
|
||||
|
||||
for i in self.model.get_outputs():
|
||||
print(i)
|
||||
|
||||
def __call__(self, x: torch.tensor):
|
||||
assert x.ndim == 2, x.shape
|
||||
assert x.shape[0] == 1, x.shape
|
||||
|
||||
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
|
||||
print("x_lengths", x_lengths)
|
||||
print("x", x.shape)
|
||||
|
||||
temperature = torch.tensor([1.0], dtype=torch.float32)
|
||||
length_scale = torch.tensor([1.0], dtype=torch.float32)
|
||||
|
||||
mel = self.model.run(
|
||||
[self.model.get_outputs()[0].name],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x.numpy(),
|
||||
self.model.get_inputs()[1].name: x_lengths.numpy(),
|
||||
self.model.get_inputs()[2].name: temperature.numpy(),
|
||||
self.model.get_inputs()[3].name: length_scale.numpy(),
|
||||
},
|
||||
)[0]
|
||||
|
||||
return torch.from_numpy(mel)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
params = get_parser().parse_args()
|
||||
logging.info(vars(params))
|
||||
|
||||
model = OnnxModel(params.acoustic_model)
|
||||
vocoder = OnnxHifiGANModel(params.vocoder)
|
||||
text = params.input_text
|
||||
x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
|
||||
x = torch.tensor(x, dtype=torch.int64)
|
||||
|
||||
start_t = dt.datetime.now()
|
||||
mel = model(x)
|
||||
end_t = dt.datetime.now()
|
||||
|
||||
start_t2 = dt.datetime.now()
|
||||
audio = vocoder(mel)
|
||||
end_t2 = dt.datetime.now()
|
||||
|
||||
print("audio", audio.shape) # (1, 1, num_samples)
|
||||
audio = audio.squeeze()
|
||||
|
||||
t = (end_t - start_t).total_seconds()
|
||||
t2 = (end_t2 - start_t2).total_seconds()
|
||||
rtf_am = t * 22050 / audio.shape[-1]
|
||||
rtf_vocoder = t2 * 22050 / audio.shape[-1]
|
||||
print("RTF for acoustic model ", rtf_am)
|
||||
print("RTF for vocoder", rtf_vocoder)
|
||||
|
||||
# skip denoiser
|
||||
sf.write(params.output_wav, audio, 22050, "PCM_16")
|
||||
logging.info(f"Saved to {params.output_wav}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
||||
|
||||
"""
|
||||
|
||||
|HifiGAN |RTF |#Parameters (M)|
|
||||
|----------|-----|---------------|
|
||||
|v1 |0.818| 13.926 |
|
||||
|v2 |0.101| 0.925 |
|
||||
|v3 |0.118| 1.462 |
|
||||
|
||||
|Num steps|Acoustic Model RTF|
|
||||
|---------|------------------|
|
||||
| 2 | 0.039 |
|
||||
| 3 | 0.047 |
|
||||
| 4 | 0.071 |
|
||||
| 5 | 0.076 |
|
||||
| 6 | 0.103 |
|
||||
|
||||
"""
|
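The RTF numbers in the table above follow the same formula used in `main`: elapsed wall-clock seconds times the 22050 Hz sample rate divided by the number of generated samples. A worked example with invented numbers:

elapsed_s = 0.25         # hypothetical time spent in the acoustic model
num_samples = 5 * 22050  # 5 seconds of generated audio
rtf = elapsed_s * 22050 / num_samples
print(rtf)               # 0.05, i.e. about 20x faster than real time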
3
egs/ljspeech/TTS/matcha/requirements.txt
Normal file
@@ -0,0 +1,3 @@
conformer==0.3.2
diffusers # developed using version ==0.25.0
librosa
1
egs/ljspeech/TTS/matcha/tokenizer.py
Symbolic link
@@ -0,0 +1 @@
../vits/tokenizer.py
723
egs/ljspeech/TTS/matcha/train.py
Executable file
@@ -0,0 +1,723 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from lhotse.utils import fix_random_seed
|
||||
from matcha.model import fix_len_compatibility
|
||||
from matcha.models.matcha_tts import MatchaTTS
|
||||
from matcha.tokenizer import Tokenizer
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tts_datamodule import LJSpeechTtsDataModule
|
||||
from utils import MetricsTracker
|
||||
|
||||
from icefall.checkpoint import load_checkpoint, save_checkpoint
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--master-port",
|
||||
type=int,
|
||||
default=12335,
|
||||
help="Master port to use for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tensorboard",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Should various information be logged in tensorboard.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Resume training from this epoch. It should be positive.
|
||||
If larger than 1, it will load checkpoint from
|
||||
exp-dir/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=Path,
|
||||
default="matcha/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/tokens.txt",
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cmvn",
|
||||
type=str,
|
||||
default="data/fbank/cmvn.json",
|
||||
help="""Path to vocabulary.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="The seed for random generators intended for reproducibility",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save-every-n",
|
||||
type=int,
|
||||
default=10,
|
||||
help="""Save checkpoint after processing this number of epochs"
|
||||
periodically. We save checkpoint to exp-dir/ whenever
|
||||
params.cur_epoch % save_every_n == 0. The checkpoint filename
|
||||
has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
|
||||
Since it will take around 1000 epochs, we suggest using a large
|
||||
save_every_n to save disk space.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-fp16",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_data_statistics():
|
||||
return AttributeDict(
|
||||
{
|
||||
"mel_mean": 0,
|
||||
"mel_std": 1,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_data_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"name": "ljspeech",
|
||||
"train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt",
|
||||
"valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt",
|
||||
# "batch_size": 64,
|
||||
# "num_workers": 1,
|
||||
# "pin_memory": False,
|
||||
"cleaners": ["english_cleaners2"],
|
||||
"add_blank": True,
|
||||
"n_spks": 1,
|
||||
"n_fft": 1024,
|
||||
"n_feats": 80,
|
||||
"sample_rate": 22050,
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"f_min": 0,
|
||||
"f_max": 8000,
|
||||
"seed": 1234,
|
||||
"load_durations": False,
|
||||
"data_statistics": get_data_statistics(),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def _get_model_params() -> AttributeDict:
|
||||
n_feats = 80
|
||||
filter_channels_dp = 256
|
||||
encoder_params_p_dropout = 0.1
|
||||
params = AttributeDict(
|
||||
{
|
||||
"n_spks": 1, # for ljspeech.
|
||||
"spk_emb_dim": 64,
|
||||
"n_feats": n_feats,
|
||||
"out_size": None, # or use 172
|
||||
"prior_loss": True,
|
||||
"use_precomputed_durations": False,
|
||||
"data_statistics": get_data_statistics(),
|
||||
"encoder": AttributeDict(
|
||||
{
|
||||
"encoder_type": "RoPE Encoder", # not used
|
||||
"encoder_params": AttributeDict(
|
||||
{
|
||||
"n_feats": n_feats,
|
||||
"n_channels": 192,
|
||||
"filter_channels": 768,
|
||||
"filter_channels_dp": filter_channels_dp,
|
||||
"n_heads": 2,
|
||||
"n_layers": 6,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": encoder_params_p_dropout,
|
||||
"spk_emb_dim": 64,
|
||||
"n_spks": 1,
|
||||
"prenet": True,
|
||||
}
|
||||
),
|
||||
"duration_predictor_params": AttributeDict(
|
||||
{
|
||||
"filter_channels_dp": filter_channels_dp,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": encoder_params_p_dropout,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
"decoder": AttributeDict(
|
||||
{
|
||||
"channels": [256, 256],
|
||||
"dropout": 0.05,
|
||||
"attention_head_dim": 64,
|
||||
"n_blocks": 1,
|
||||
"num_mid_blocks": 2,
|
||||
"num_heads": 2,
|
||||
"act_fn": "snakebeta",
|
||||
}
|
||||
),
|
||||
"cfm": AttributeDict(
|
||||
{
|
||||
"name": "CFM",
|
||||
"solver": "euler",
|
||||
"sigma_min": 1e-4,
|
||||
}
|
||||
),
|
||||
"optimizer": AttributeDict(
|
||||
{
|
||||
"lr": 1e-4,
|
||||
"weight_decay": 0.0,
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def get_params():
|
||||
params = AttributeDict(
|
||||
{
|
||||
"model_args": _get_model_params(),
|
||||
"data_args": _get_data_params(),
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": -1, # 0
|
||||
"log_interval": 10,
|
||||
"valid_interval": 1500,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def get_model(params):
|
||||
m = MatchaTTS(**params.model_args)
|
||||
return m
|
||||
|
||||
|
||||
def load_checkpoint_if_available(
|
||||
params: AttributeDict, model: nn.Module
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Load checkpoint from file.
|
||||
|
||||
If params.start_epoch is larger than 1, it will load the checkpoint from
|
||||
`params.start_epoch - 1`.
|
||||
|
||||
Apart from loading state dict for `model` and `optimizer` it also updates
|
||||
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
and `best_valid_loss` in `params`.
|
||||
|
||||
Args:
|
||||
params:
|
||||
The return value of :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
Returns:
|
||||
Return a dict containing previously saved training info.
|
||||
"""
|
||||
if params.start_epoch > 1:
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
else:
|
||||
return None
|
||||
|
||||
assert filename.is_file(), f"{filename} does not exist!"
|
||||
|
||||
saved_params = load_checkpoint(filename, model=model)
|
||||
|
||||
keys = [
|
||||
"best_train_epoch",
|
||||
"best_valid_epoch",
|
||||
"batch_idx_train",
|
||||
"best_train_loss",
|
||||
"best_valid_loss",
|
||||
]
|
||||
for k in keys:
|
||||
params[k] = saved_params[k]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params):
|
||||
"""Parse batch data"""
|
||||
mel_mean = params.data_args.data_statistics.mel_mean
|
||||
mel_std_inv = 1 / params.data_args.data_statistics.mel_std
|
||||
for i in range(batch["features"].shape[0]):
|
||||
n = batch["features_lens"][i]
|
||||
batch["features"][i : i + 1, :n, :] = (
|
||||
batch["features"][i : i + 1, :n, :] - mel_mean
|
||||
) * mel_std_inv
|
||||
batch["features"][i : i + 1, n:, :] = 0
|
||||
|
||||
audio = batch["audio"].to(device)
|
||||
features = batch["features"].to(device)
|
||||
audio_lens = batch["audio_lens"].to(device)
|
||||
features_lens = batch["features_lens"].to(device)
|
||||
tokens = batch["tokens"]
|
||||
|
||||
tokens = tokenizer.tokens_to_token_ids(
|
||||
tokens, intersperse_blank=True, add_sos=True, add_eos=True
|
||||
)
|
||||
tokens = k2.RaggedTensor(tokens)
|
||||
row_splits = tokens.shape.row_splits(1)
|
||||
tokens_lens = row_splits[1:] - row_splits[:-1]
|
||||
tokens = tokens.to(device)
|
||||
tokens_lens = tokens_lens.to(device)
|
||||
# a tensor of shape (B, T)
|
||||
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
|
||||
|
||||
max_feature_length = fix_len_compatibility(features.shape[1])
|
||||
if max_feature_length > features.shape[1]:
|
||||
pad = max_feature_length - features.shape[1]
|
||||
features = torch.nn.functional.pad(features, (0, 0, 0, pad))
|
||||
|
||||
# features_lens[features_lens.argmax()] += pad
|
||||
|
||||
return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long()
|
||||
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
tokenizer: Tokenizer,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
) -> MetricsTracker:
|
||||
"""Run the validation process."""
|
||||
model.eval()
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
|
||||
|
||||
# used to summary the stats over iterations
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
(
|
||||
audio,
|
||||
audio_lens,
|
||||
features,
|
||||
features_lens,
|
||||
tokens,
|
||||
tokens_lens,
|
||||
) = prepare_input(batch, tokenizer, device, params)
|
||||
|
||||
losses = get_losses(
|
||||
{
|
||||
"x": tokens,
|
||||
"x_lengths": tokens_lens,
|
||||
"y": features.permute(0, 2, 1),
|
||||
"y_lengths": features_lens,
|
||||
"spks": None, # should change it for multi-speakers
|
||||
"durations": None,
|
||||
}
|
||||
)
|
||||
|
||||
batch_size = len(batch["tokens"])
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
|
||||
s = 0
|
||||
|
||||
for key, value in losses.items():
|
||||
v = value.detach().item()
|
||||
loss_info[key] = v * batch_size
|
||||
s += v * batch_size
|
||||
|
||||
loss_info["tot_loss"] = s
|
||||
|
||||
# summary stats
|
||||
tot_loss = tot_loss + loss_info
|
||||
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(device)
|
||||
|
||||
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
|
||||
if loss_value < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = loss_value
|
||||
|
||||
return tot_loss
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
tokenizer: Tokenizer,
|
||||
optimizer: Optimizer,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
The training loss from the mean of all frames is saved in
|
||||
`params.train_loss`. It runs the validation process every
|
||||
`params.valid_interval` batches.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The model for training.
|
||||
optimizer:
|
||||
The optimizer.
|
||||
train_dl:
|
||||
Dataloader for the training dataset.
|
||||
valid_dl:
|
||||
Dataloader for the validation dataset.
|
||||
scaler:
|
||||
The scaler used for mix precision training.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
"""
|
||||
model.train()
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses
|
||||
|
||||
# used to track the stats over iterations in one epoch
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
saved_bad_model = False
|
||||
|
||||
def save_bad_model(suffix: str = ""):
|
||||
save_checkpoint(
|
||||
filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scaler=scaler,
|
||||
rank=0,
|
||||
)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
# audio: (N, T), float32
|
||||
# features: (N, T, C), float32
|
||||
# audio_lens, (N,), int32
|
||||
# features_lens, (N,), int32
|
||||
# tokens: List[List[str]], len(tokens) == N
|
||||
|
||||
batch_size = len(batch["tokens"])
|
||||
|
||||
(
|
||||
audio,
|
||||
audio_lens,
|
||||
features,
|
||||
features_lens,
|
||||
tokens,
|
||||
tokens_lens,
|
||||
) = prepare_input(batch, tokenizer, device, params)
|
||||
try:
|
||||
with autocast(enabled=params.use_fp16):
|
||||
losses = get_losses(
|
||||
{
|
||||
"x": tokens,
|
||||
"x_lengths": tokens_lens,
|
||||
"y": features.permute(0, 2, 1),
|
||||
"y_lengths": features_lens,
|
||||
"spks": None, # should change it for multi-speakers
|
||||
"durations": None,
|
||||
}
|
||||
)
|
||||
|
||||
loss = sum(losses.values())
|
||||
|
||||
optimizer.zero_grad()
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
|
||||
s = 0
|
||||
|
||||
for key, value in losses.items():
|
||||
v = value.detach().item()
|
||||
loss_info[key] = v * batch_size
|
||||
s += v * batch_size
|
||||
|
||||
loss_info["tot_loss"] = s
|
||||
|
||||
tot_loss = tot_loss + loss_info
|
||||
except: # noqa
|
||||
save_bad_model()
|
||||
raise
|
||||
|
||||
if params.batch_idx_train % 100 == 0 and params.use_fp16:
|
||||
# If the grad scale was less than 1, try increasing it.
|
||||
# The _growth_interval of the grad scaler is configurable,
|
||||
# but we can't configure it to have different
|
||||
# behavior depending on the current grad scale.
|
||||
cur_grad_scale = scaler._scale.item()
|
||||
|
||||
if cur_grad_scale < 8.0 or (
|
||||
cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0
|
||||
):
|
||||
scaler.update(cur_grad_scale * 2.0)
|
||||
if cur_grad_scale < 0.01:
|
||||
if not saved_bad_model:
|
||||
save_bad_model(suffix="-first-warning")
|
||||
saved_bad_model = True
|
||||
logging.warning(f"Grad scale is small: {cur_grad_scale}")
|
||||
if cur_grad_scale < 1.0e-05:
|
||||
save_bad_model()
|
||||
raise RuntimeError(
|
||||
f"grad_scale is too small, exiting: {cur_grad_scale}"
|
||||
)
|
||||
|
||||
if params.batch_idx_train % params.log_interval == 0:
|
||||
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
||||
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||
f"global_batch_idx: {params.batch_idx_train}, "
|
||||
f"batch size: {batch_size}, "
|
||||
f"loss[{loss_info}], tot_loss[{tot_loss}], "
|
||||
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
||||
)
|
||||
|
||||
if tb_writer is not None:
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
if params.use_fp16:
|
||||
tb_writer.add_scalar(
|
||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||
)
|
||||
|
||||
if params.batch_idx_train % params.valid_interval == 1:
|
||||
logging.info("Computing validation loss")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
model.train()
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
logging.info(
|
||||
"Maximum memory allocated so far is "
|
||||
f"{torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
|
||||
loss_value = tot_loss["tot_loss"] / tot_loss["samples"]
|
||||
params.train_loss = loss_value
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
params.best_train_loss = params.train_loss
|
||||
|
||||
|
||||
def run(rank, world_size, args):
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
fix_random_seed(params.seed)
|
||||
if world_size > 1:
|
||||
setup_dist(rank, world_size, params.master_port)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
tokenizer = Tokenizer(params.tokens)
|
||||
params.pad_id = tokenizer.pad_id
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
params.model_args.n_vocab = params.vocab_size
|
||||
|
||||
with open(params.cmvn) as f:
|
||||
stats = json.load(f)
|
||||
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.data_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
|
||||
params.model_args.data_statistics.mel_std = stats["fbank_std"]
|
||||
|
||||
logging.info(params)
|
||||
print(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of parameters: {num_param}")
|
||||
|
||||
assert params.start_epoch > 0, params.start_epoch
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer)
|
||||
|
||||
logging.info("About to create datamodule")
|
||||
|
||||
ljspeech = LJSpeechTtsDataModule(args)
|
||||
|
||||
train_cuts = ljspeech.train_cuts()
|
||||
train_dl = ljspeech.train_dataloaders(train_cuts)
|
||||
|
||||
valid_cuts = ljspeech.valid_cuts()
|
||||
valid_dl = ljspeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs + 1):
|
||||
logging.info(f"Start epoch {epoch}")
|
||||
fix_random_seed(params.seed + epoch - 1)
|
||||
if "sampler" in train_dl:
|
||||
train_dl.sampler.set_epoch(epoch - 1)
|
||||
|
||||
params.cur_epoch = epoch
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
optimizer=optimizer,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
scaler=scaler,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
|
||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||
save_checkpoint(
|
||||
filename=filename,
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
if rank == 0:
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LJSpeechTtsDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||
else:
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
main()
|
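`prepare_input` normalises the valid fbank frames of each utterance with the global statistics loaded from `--cmvn` and zeroes the padded frames. A small sketch of the same transform on a dummy batch (values and shapes invented); at synthesis time the model maps its predictions back with `denormalize`, i.e. `mel = x * mel_std + mel_mean`:

import torch

mel_mean, mel_std = -5.0, 2.5       # hypothetical values read from data/fbank/cmvn.json
features = torch.randn(2, 100, 80)  # (batch, frames, bins)
features_lens = torch.tensor([100, 60])

for i in range(features.shape[0]):
    n = features_lens[i]
    features[i, :n, :] = (features[i, :n, :] - mel_mean) / mel_std
    features[i, n:, :] = 0          # zero out padding frames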
341
egs/ljspeech/TTS/matcha/tts_datamodule.py
Normal file
@@ -0,0 +1,341 @@
# Copyright         2021  Piotr Żelasko
# Copyright    2022-2023  Xiaomi Corporation  (Authors: Mingshuang Luo,
#                                                       Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from compute_fbank_ljspeech import MyFbank, MyFbankConfig
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
    SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
    AudioSamples,
    OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class LJSpeechTtsDataModule:
    """
    DataModule for tts experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/fbank"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler"
            "(you might want to increase it for larger datasets).",
        )

        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=False,
            help="When enabled, each batch will have the "
            "field: batch['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to create train dataset")
        train = SpeechSynthesisDataset(
            return_text=False,
            return_tokens=True,
            feature_input_strategy=eval(self.args.input_strategy)(),
            return_cuts=self.args.return_cuts,
        )

        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = MyFbankConfig(
                n_fft=1024,
                n_mels=80,
                sampling_rate=sampling_rate,
                hop_length=256,
                win_length=1024,
                f_min=0,
                f_max=8000,
            )
            train = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                return_cuts=self.args.return_cuts,
            )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=self.args.num_buckets * 2000,
                shuffle_buffer_size=self.args.num_buckets * 5000,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=True,
            pin_memory=True,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = MyFbankConfig(
                n_fft=1024,
                n_mels=80,
                sampling_rate=sampling_rate,
                hop_length=256,
                win_length=1024,
                f_min=0,
                f_max=8000,
            )
            validate = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            num_buckets=self.args.num_buckets,
            shuffle=False,
        )
        logging.info("About to create valid dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=True,
            pin_memory=True,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.info("About to create test dataset")
        if self.args.on_the_fly_feats:
            sampling_rate = 22050
            config = MyFbankConfig(
                n_fft=1024,
                n_mels=80,
                sampling_rate=sampling_rate,
                hop_length=256,
                win_length=1024,
                f_min=0,
                f_max=8000,
            )
            test = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=OnTheFlyFeatures(MyFbank(config)),
                return_cuts=self.args.return_cuts,
            )
        else:
            test = SpeechSynthesisDataset(
                return_text=False,
                return_tokens=True,
                feature_input_strategy=eval(self.args.input_strategy)(),
                return_cuts=self.args.return_cuts,
            )
        test_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            num_buckets=self.args.num_buckets,
            shuffle=False,
        )
        logging.info("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=test_sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
        )

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        logging.info("About to get validation cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
        )

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.info("About to get test cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
        )
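Outside of train.py (which wires this module up through get_parser() and LJSpeechTtsDataModule.add_arguments(), as shown above), the class can also be driven directly. A rough standalone sketch follows; it assumes you run it from egs/ljspeech/TTS/matcha so that compute_fbank_ljspeech.py is importable and that prepare.sh has already produced data/fbank/ljspeech_cuts_*.jsonl.gz. The exact batch fields come from lhotse's SpeechSynthesisDataset and are not spelled out in this commit.

# Standalone usage sketch only -- not part of this commit.
import argparse

from tts_datamodule import LJSpeechTtsDataModule

parser = argparse.ArgumentParser()
LJSpeechTtsDataModule.add_arguments(parser)
# Use the defaults (manifest-dir=data/fbank, max-duration=200, etc.).
args = parser.parse_args([])

ljspeech = LJSpeechTtsDataModule(args)
train_dl = ljspeech.train_dataloaders(ljspeech.train_cuts())
valid_dl = ljspeech.valid_dataloaders(ljspeech.valid_cuts())

for batch in train_dl:
    # SpeechSynthesisDataset yields a dict per batch (tokens, features, ...).
    print(batch.keys())
    break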
1
egs/ljspeech/TTS/matcha/utils.py
Symbolic link
@@ -0,0 +1 @@
../vits/utils.py
@@ -5,7 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

set -eou pipefail

-stage=0
+stage=-1
stop_stage=100

dl_dir=$PWD/download
@@ -31,7 +31,19 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
    python3 setup.py build_ext --inplace
    cd ../../
  else
-   log "monotonic_align lib already built"
+   log "monotonic_align lib for vits already built"
  fi
+
+  if [ ! -f ./matcha/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then
+    pushd matcha/monotonic_align
+    python3 setup.py build
+    mv -v build/lib.*/matcha/monotonic_align/core.*.so .
+    rm -rf build
+    rm core.c
+    ls -lh
+    popd
+  else
+    log "monotonic_align lib for matcha-tts already built"
+  fi
fi

@@ -63,7 +75,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Compute spectrogram for LJSpeech"
+  log "Stage 2: Compute spectrogram for LJSpeech (used by ./vits)"
  mkdir -p data/spectrogram
  if [ ! -e data/spectrogram/.ljspeech.done ]; then
    ./local/compute_spectrogram_ljspeech.py
@@ -71,7 +83,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  fi

  if [ ! -e data/spectrogram/.ljspeech-validated.done ]; then
-    log "Validating data/spectrogram for LJSpeech"
+    log "Validating data/spectrogram for LJSpeech (used by ./vits)"
    python3 ./local/validate_manifest.py \
      data/spectrogram/ljspeech_cuts_all.jsonl.gz
    touch data/spectrogram/.ljspeech-validated.done
@@ -79,13 +91,13 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 3: Prepare phoneme tokens for LJSpeech"
+  log "Stage 3: Prepare phoneme tokens for LJSpeech (used by ./vits)"
  # We assume you have installed piper_phonemize and espnet_tts_frontend.
  # If not, please install them with:
  #  - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html,
  #  - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
  if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
-    ./local/prepare_tokens_ljspeech.py
+    ./local/prepare_tokens_ljspeech.py --in-out-dir ./data/spectrogram
    mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
      data/spectrogram/ljspeech_cuts_all.jsonl.gz
    touch data/spectrogram/.ljspeech_with_token.done
@@ -93,7 +105,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 4: Split the LJSpeech cuts into train, valid and test sets"
+  log "Stage 4: Split the LJSpeech cuts into train, valid and test sets (used by vits)"
  if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
    lhotse subset --last 600 \
      data/spectrogram/ljspeech_cuts_all.jsonl.gz \
@@ -126,3 +138,63 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
    ./local/prepare_token_file.py --tokens data/tokens.txt
  fi
fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Generate fbank (used by ./matcha)"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.ljspeech.done ]; then
+    ./local/compute_fbank_ljspeech.py
+    touch data/fbank/.ljspeech.done
+  fi
+
+  if [ ! -e data/fbank/.ljspeech-validated.done ]; then
+    log "Validating data/fbank for LJSpeech (used by ./matcha)"
+    python3 ./local/validate_manifest.py \
+      data/fbank/ljspeech_cuts_all.jsonl.gz
+    touch data/fbank/.ljspeech-validated.done
+  fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare phoneme tokens for LJSpeech (used by ./matcha)"
+  # We assume you have installed piper_phonemize and espnet_tts_frontend.
+  # If not, please install them with:
+  #  - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html,
+  #  - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
+  if [ ! -e data/fbank/.ljspeech_with_token.done ]; then
+    ./local/prepare_tokens_ljspeech.py --in-out-dir ./data/fbank
+    mv data/fbank/ljspeech_cuts_with_tokens_all.jsonl.gz \
+      data/fbank/ljspeech_cuts_all.jsonl.gz
+    touch data/fbank/.ljspeech_with_token.done
+  fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Split the LJSpeech cuts into train, valid and test sets (used by ./matcha)"
+  if [ ! -e data/fbank/.ljspeech_split.done ]; then
+    lhotse subset --last 600 \
+      data/fbank/ljspeech_cuts_all.jsonl.gz \
+      data/fbank/ljspeech_cuts_validtest.jsonl.gz
+    lhotse subset --first 100 \
+      data/fbank/ljspeech_cuts_validtest.jsonl.gz \
+      data/fbank/ljspeech_cuts_valid.jsonl.gz
+    lhotse subset --last 500 \
+      data/fbank/ljspeech_cuts_validtest.jsonl.gz \
+      data/fbank/ljspeech_cuts_test.jsonl.gz
+
+    rm data/fbank/ljspeech_cuts_validtest.jsonl.gz
+
+    n=$(( $(gunzip -c data/fbank/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 ))
+    lhotse subset --first $n \
+      data/fbank/ljspeech_cuts_all.jsonl.gz \
+      data/fbank/ljspeech_cuts_train.jsonl.gz
+    touch data/fbank/.ljspeech_split.done
+  fi
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compute fbank mean and std (used by ./matcha)"
+  if [ ! -f ./data/fbank/cmvn.json ]; then
+    ./local/compute_fbank_statistics.py ./data/fbank/ljspeech_cuts_train.jsonl.gz ./data/fbank/cmvn.json
+  fi
+fi
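Stage 8 above mirrors the earlier vits split: the last 600 cuts are set aside, the first 100 of those become the validation set, the last 500 the test set, and everything before them becomes the training set. For the stock LJSpeech-1.1 release (13100 clips) that leaves 12500 training cuts. A small sanity-check sketch, assuming the split has finished and relying on each cut being one JSON line in the manifest:

# Sanity-check sketch only -- not part of this commit.
# Expected counts for stock LJSpeech-1.1 (13100 clips): train 12500, valid 100, test 500.
import gzip

for split in ("train", "valid", "test"):
    path = f"data/fbank/ljspeech_cuts_{split}.jsonl.gz"
    with gzip.open(path, "rt") as f:
        print(split, sum(1 for _ in f))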
@@ -90,7 +90,7 @@ def save_checkpoint(

    if params:
        for k, v in params.items():
-            assert k not in checkpoint
+            assert k not in checkpoint, k
            checkpoint[k] = v

    torch.save(checkpoint, filename)
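The only change in this hunk is the assert message: when a key from params collides with an entry already present in the checkpoint dict, the failing key is now reported instead of a bare AssertionError. A tiny illustration with made-up values (not from the commit):

# Illustration only -- not part of this commit.
checkpoint = {"model": None, "grad_scaler": None}
params = {"grad_scaler": 1.0}

for k, v in params.items():
    try:
        assert k not in checkpoint, k
    except AssertionError as e:
        print(f"duplicate checkpoint key: {e}")  # -> duplicate checkpoint key: grad_scaler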