mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Merge branch 'k2-fsa:master' into dev/k2ssl
This commit is contained in:
commit
cf5fd1a2e0
@ -1 +1 @@
|
|||||||
../../../icefall/shared/
|
../../../icefall/shared
|
@ -17,6 +17,7 @@ class MatchaFbankConfig:
|
|||||||
win_length: int
|
win_length: int
|
||||||
f_min: float
|
f_min: float
|
||||||
f_max: float
|
f_max: float
|
||||||
|
device: str = "cuda"
|
||||||
|
|
||||||
|
|
||||||
@register_extractor
|
@register_extractor
|
||||||
@ -46,7 +47,7 @@ class MatchaFbank(FeatureExtractor):
|
|||||||
f"Mismatched sampling rate: extractor expects {expected_sr}, "
|
f"Mismatched sampling rate: extractor expects {expected_sr}, "
|
||||||
f"got {sampling_rate}"
|
f"got {sampling_rate}"
|
||||||
)
|
)
|
||||||
samples = torch.from_numpy(samples)
|
samples = torch.from_numpy(samples).to(self.device)
|
||||||
assert samples.ndim == 2, samples.shape
|
assert samples.ndim == 2, samples.shape
|
||||||
assert samples.shape[0] == 1, samples.shape
|
assert samples.shape[0] == 1, samples.shape
|
||||||
|
|
||||||
@ -81,7 +82,7 @@ class MatchaFbank(FeatureExtractor):
|
|||||||
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
|
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
|
||||||
).squeeze(0)
|
).squeeze(0)
|
||||||
|
|
||||||
return mel.numpy()
|
return mel.cpu().numpy()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def frame_shift(self) -> Seconds:
|
def frame_shift(self) -> Seconds:
|
||||||
|
@ -68,5 +68,69 @@ python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_
|
|||||||
--text-extractor pypinyin_initials_finals --top-p ${top_p}
|
--text-extractor pypinyin_initials_finals --top-p ${top_p}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
# [F5-TTS](https://arxiv.org/abs/2410.06885)
|
||||||
|
|
||||||
|
./f5-tts contains the code for training F5-TTS model.
|
||||||
|
|
||||||
|
Generated samples and training logs of wenetspeech basic 7k hours data can be found [here](https://huggingface.co/yuekai/f5-tts-small-wenetspeech4tts-basic/tensorboard).
|
||||||
|
|
||||||
|
Preparation:
|
||||||
|
|
||||||
|
```
|
||||||
|
bash prepare.sh --stage 5 --stop_stage 6
|
||||||
|
```
|
||||||
|
(Note: To compatiable with F5-TTS official checkpoint, we direclty use `vocab.txt` from [here.](https://github.com/SWivid/F5-TTS/blob/129014c5b43f135b0100d49a0c6804dd4cf673e1/data/Emilia_ZH_EN_pinyin/vocab.txt) To generate your own `vocab.txt`, you may refer to [the script](https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/train/datasets/prepare_emilia.py).)
|
||||||
|
|
||||||
|
The training command is given below:
|
||||||
|
|
||||||
|
```
|
||||||
|
# docker: ghcr.io/swivid/f5-tts:main
|
||||||
|
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
|
||||||
|
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
|
||||||
|
|
||||||
|
world_size=8
|
||||||
|
exp_dir=exp/f5-tts-small
|
||||||
|
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
|
||||||
|
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
|
||||||
|
--base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
|
||||||
|
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
|
||||||
|
--exp-dir ${exp_dir} --world-size ${world_size}
|
||||||
|
```
|
||||||
|
|
||||||
|
To inference with Icefall Wenetspeech4TTS trained F5-Small, use:
|
||||||
|
```
|
||||||
|
huggingface-cli login
|
||||||
|
huggingface-cli download --local-dir seed_tts_eval yuekai/seed_tts_eval --repo-type dataset
|
||||||
|
huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-small-wenetspeech4tts-basic
|
||||||
|
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
|
||||||
|
|
||||||
|
manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst
|
||||||
|
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-56-avg-14.pt
|
||||||
|
# skip
|
||||||
|
python3 f5-tts/generate_averaged_model.py \
|
||||||
|
--epoch 56 \
|
||||||
|
--avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
|
||||||
|
--exp-dir exp/f5_small
|
||||||
|
|
||||||
|
|
||||||
|
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
|
||||||
|
bash local/compute_wer.sh $output_dir $manifest
|
||||||
|
```
|
||||||
|
|
||||||
|
To inference with official Emilia trained F5-Base, use:
|
||||||
|
```
|
||||||
|
huggingface-cli login
|
||||||
|
huggingface-cli download --local-dir seed_tts_eval yuekai/seed_tts_eval --repo-type dataset
|
||||||
|
huggingface-cli download --local-dir F5-TTS SWivid/F5-TTS
|
||||||
|
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
|
||||||
|
|
||||||
|
manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst
|
||||||
|
model_path=./F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
|
||||||
|
|
||||||
|
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir
|
||||||
|
bash local/compute_wer.sh $output_dir $manifest
|
||||||
|
```
|
||||||
|
|
||||||
# Credits
|
# Credits
|
||||||
- [vall-e](https://github.com/lifeiteng/vall-e)
|
- [VALL-E](https://github.com/lifeiteng/vall-e)
|
||||||
|
- [F5-TTS](https://github.com/SWivid/F5-TTS)
|
||||||
|
173
egs/wenetspeech4tts/TTS/f5-tts/generate_averaged_model.py
Normal file
173
egs/wenetspeech4tts/TTS/f5-tts/generate_averaged_model.py
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
|
||||||
|
# Copyright 2024 Yuekai Zhang
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
(1) use the checkpoint exp_dir/epoch-xxx.pt
|
||||||
|
python3 bin/generate_averaged_model.py \
|
||||||
|
--epoch 40 \
|
||||||
|
--avg 5 \
|
||||||
|
--exp-dir ${exp_dir}
|
||||||
|
|
||||||
|
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
|
||||||
|
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
from train import add_model_arguments, get_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
)
|
||||||
|
from icefall.utils import AttributeDict
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=9,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="zipformer/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
add_model_arguments(parser)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = AttributeDict()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"checkpoint-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
print("Script started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
print(f"Device: {device}")
|
||||||
|
|
||||||
|
print("About to create model")
|
||||||
|
filename = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
checkpoint = torch.load(filename, map_location=device)
|
||||||
|
args = AttributeDict(checkpoint)
|
||||||
|
model = get_model(args)
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
# TODO FIX ME
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
print(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
filename = params.exp_dir / f"checkpoint-{params.iter}-avg-{params.avg}.pt"
|
||||||
|
torch.save({"model": model.state_dict()}, filename)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
print(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
filenames = [
|
||||||
|
f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1)
|
||||||
|
]
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
|
||||||
|
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||||
|
checkpoint["model"] = model.state_dict()
|
||||||
|
torch.save(checkpoint, filename)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
364
egs/wenetspeech4tts/TTS/f5-tts/infer.py
Normal file
364
egs/wenetspeech4tts/TTS/f5-tts/infer.py
Normal file
@ -0,0 +1,364 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
# docker: ghcr.io/swivid/f5-tts:main
|
||||||
|
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
|
||||||
|
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx
|
||||||
|
# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
|
||||||
|
manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst
|
||||||
|
python3 f5-tts/generate_averaged_model.py \
|
||||||
|
--epoch 56 \
|
||||||
|
--avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
|
||||||
|
--exp-dir exp/f5_small
|
||||||
|
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
|
||||||
|
bash local/compute_wer.sh $output_dir $manifest
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from bigvganinference import BigVGANInference
|
||||||
|
from model.cfm import CFM
|
||||||
|
from model.dit import DiT
|
||||||
|
from model.modules import MelSpec
|
||||||
|
from model.utils import convert_char_to_pinyin
|
||||||
|
from tqdm import tqdm
|
||||||
|
from train import (
|
||||||
|
add_model_arguments,
|
||||||
|
get_model,
|
||||||
|
get_tokenizer,
|
||||||
|
load_F5_TTS_pretrained_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
from icefall.checkpoint import load_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
default="f5-tts/vocab.txt",
|
||||||
|
help="Path to the unique text tokens file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-path",
|
||||||
|
type=str,
|
||||||
|
default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt",
|
||||||
|
help="Path to the unique text tokens file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="The seed for random generators intended for reproducibility",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nfe",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="The number of steps for the neural ODE",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--manifest-file",
|
||||||
|
type=str,
|
||||||
|
default="/path/seed_tts_eval/seedtts_testset/zh/meta.lst",
|
||||||
|
help="The manifest file in seed_tts_eval format",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=Path,
|
||||||
|
default="results",
|
||||||
|
help="The output directory to save the generated wavs",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
||||||
|
add_model_arguments(parser)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def get_inference_prompt(
|
||||||
|
metainfo,
|
||||||
|
speed=1.0,
|
||||||
|
tokenizer="pinyin",
|
||||||
|
polyphone=True,
|
||||||
|
target_sample_rate=24000,
|
||||||
|
n_fft=1024,
|
||||||
|
win_length=1024,
|
||||||
|
n_mel_channels=100,
|
||||||
|
hop_length=256,
|
||||||
|
mel_spec_type="bigvgan",
|
||||||
|
target_rms=0.1,
|
||||||
|
use_truth_duration=False,
|
||||||
|
infer_batch_size=1,
|
||||||
|
num_buckets=200,
|
||||||
|
min_secs=3,
|
||||||
|
max_secs=40,
|
||||||
|
):
|
||||||
|
prompts_all = []
|
||||||
|
|
||||||
|
min_tokens = min_secs * target_sample_rate // hop_length
|
||||||
|
max_tokens = max_secs * target_sample_rate // hop_length
|
||||||
|
|
||||||
|
batch_accum = [0] * num_buckets
|
||||||
|
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
|
||||||
|
[[] for _ in range(num_buckets)] for _ in range(6)
|
||||||
|
)
|
||||||
|
|
||||||
|
mel_spectrogram = MelSpec(
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
n_mel_channels=n_mel_channels,
|
||||||
|
target_sample_rate=target_sample_rate,
|
||||||
|
mel_spec_type=mel_spec_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
|
||||||
|
metainfo, desc="Processing prompts..."
|
||||||
|
):
|
||||||
|
# Audio
|
||||||
|
ref_audio, ref_sr = torchaudio.load(prompt_wav)
|
||||||
|
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
|
||||||
|
if ref_rms < target_rms:
|
||||||
|
ref_audio = ref_audio * target_rms / ref_rms
|
||||||
|
assert (
|
||||||
|
ref_audio.shape[-1] > 5000
|
||||||
|
), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
|
||||||
|
if ref_sr != target_sample_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||||
|
ref_audio = resampler(ref_audio)
|
||||||
|
|
||||||
|
# Text
|
||||||
|
if len(prompt_text[-1].encode("utf-8")) == 1:
|
||||||
|
prompt_text = prompt_text + " "
|
||||||
|
text = [prompt_text + gt_text]
|
||||||
|
if tokenizer == "pinyin":
|
||||||
|
text_list = convert_char_to_pinyin(text, polyphone=polyphone)
|
||||||
|
else:
|
||||||
|
text_list = text
|
||||||
|
|
||||||
|
# Duration, mel frame length
|
||||||
|
ref_mel_len = ref_audio.shape[-1] // hop_length
|
||||||
|
if use_truth_duration:
|
||||||
|
gt_audio, gt_sr = torchaudio.load(gt_wav)
|
||||||
|
if gt_sr != target_sample_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
|
||||||
|
gt_audio = resampler(gt_audio)
|
||||||
|
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
|
||||||
|
|
||||||
|
# # test vocoder resynthesis
|
||||||
|
# ref_audio = gt_audio
|
||||||
|
else:
|
||||||
|
ref_text_len = len(prompt_text.encode("utf-8"))
|
||||||
|
gen_text_len = len(gt_text.encode("utf-8"))
|
||||||
|
total_mel_len = ref_mel_len + int(
|
||||||
|
ref_mel_len / ref_text_len * gen_text_len / speed
|
||||||
|
)
|
||||||
|
|
||||||
|
# to mel spectrogram
|
||||||
|
ref_mel = mel_spectrogram(ref_audio)
|
||||||
|
ref_mel = ref_mel.squeeze(0)
|
||||||
|
|
||||||
|
# deal with batch
|
||||||
|
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
||||||
|
assert (
|
||||||
|
min_tokens <= total_mel_len <= max_tokens
|
||||||
|
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
||||||
|
bucket_i = math.floor(
|
||||||
|
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
|
||||||
|
)
|
||||||
|
|
||||||
|
utts[bucket_i].append(utt)
|
||||||
|
ref_rms_list[bucket_i].append(ref_rms)
|
||||||
|
ref_mels[bucket_i].append(ref_mel)
|
||||||
|
ref_mel_lens[bucket_i].append(ref_mel_len)
|
||||||
|
total_mel_lens[bucket_i].append(total_mel_len)
|
||||||
|
final_text_list[bucket_i].extend(text_list)
|
||||||
|
|
||||||
|
batch_accum[bucket_i] += total_mel_len
|
||||||
|
|
||||||
|
if batch_accum[bucket_i] >= infer_batch_size:
|
||||||
|
prompts_all.append(
|
||||||
|
(
|
||||||
|
utts[bucket_i],
|
||||||
|
ref_rms_list[bucket_i],
|
||||||
|
padded_mel_batch(ref_mels[bucket_i]),
|
||||||
|
ref_mel_lens[bucket_i],
|
||||||
|
total_mel_lens[bucket_i],
|
||||||
|
final_text_list[bucket_i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
batch_accum[bucket_i] = 0
|
||||||
|
(
|
||||||
|
utts[bucket_i],
|
||||||
|
ref_rms_list[bucket_i],
|
||||||
|
ref_mels[bucket_i],
|
||||||
|
ref_mel_lens[bucket_i],
|
||||||
|
total_mel_lens[bucket_i],
|
||||||
|
final_text_list[bucket_i],
|
||||||
|
) = (
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# add residual
|
||||||
|
for bucket_i, bucket_frames in enumerate(batch_accum):
|
||||||
|
if bucket_frames > 0:
|
||||||
|
prompts_all.append(
|
||||||
|
(
|
||||||
|
utts[bucket_i],
|
||||||
|
ref_rms_list[bucket_i],
|
||||||
|
padded_mel_batch(ref_mels[bucket_i]),
|
||||||
|
ref_mel_lens[bucket_i],
|
||||||
|
total_mel_lens[bucket_i],
|
||||||
|
final_text_list[bucket_i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# not only leave easy work for last workers
|
||||||
|
random.seed(666)
|
||||||
|
random.shuffle(prompts_all)
|
||||||
|
|
||||||
|
return prompts_all
|
||||||
|
|
||||||
|
|
||||||
|
def padded_mel_batch(ref_mels):
|
||||||
|
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
|
||||||
|
padded_ref_mels = []
|
||||||
|
for mel in ref_mels:
|
||||||
|
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
|
||||||
|
padded_ref_mels.append(padded_ref_mel)
|
||||||
|
padded_ref_mels = torch.stack(padded_ref_mels)
|
||||||
|
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
|
||||||
|
return padded_ref_mels
|
||||||
|
|
||||||
|
|
||||||
|
def get_seedtts_testset_metainfo(metalst):
|
||||||
|
f = open(metalst)
|
||||||
|
lines = f.readlines()
|
||||||
|
f.close()
|
||||||
|
metainfo = []
|
||||||
|
for line in lines:
|
||||||
|
assert len(line.strip().split("|")) == 4
|
||||||
|
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
|
||||||
|
utt = Path(utt).stem
|
||||||
|
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
|
||||||
|
if not os.path.isabs(prompt_wav):
|
||||||
|
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
||||||
|
metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
|
||||||
|
return metainfo
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_parser()
|
||||||
|
|
||||||
|
accelerator = Accelerator()
|
||||||
|
device = f"cuda:{accelerator.process_index}"
|
||||||
|
|
||||||
|
metainfo = get_seedtts_testset_metainfo(args.manifest_file)
|
||||||
|
prompts_all = get_inference_prompt(
|
||||||
|
metainfo,
|
||||||
|
speed=1.0,
|
||||||
|
tokenizer="pinyin",
|
||||||
|
target_sample_rate=24_000,
|
||||||
|
n_mel_channels=100,
|
||||||
|
hop_length=256,
|
||||||
|
mel_spec_type="bigvgan",
|
||||||
|
target_rms=0.1,
|
||||||
|
use_truth_duration=False,
|
||||||
|
infer_batch_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
vocoder = BigVGANInference.from_pretrained(
|
||||||
|
"./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
|
||||||
|
)
|
||||||
|
vocoder = vocoder.eval().to(device)
|
||||||
|
|
||||||
|
model = get_model(args).eval().to(device)
|
||||||
|
checkpoint = torch.load(args.model_path, map_location="cpu")
|
||||||
|
if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
|
||||||
|
model = load_F5_TTS_pretrained_checkpoint(model, args.model_path)
|
||||||
|
else:
|
||||||
|
_ = load_checkpoint(
|
||||||
|
args.model_path,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
start = time.time()
|
||||||
|
|
||||||
|
with accelerator.split_between_processes(prompts_all) as prompts:
|
||||||
|
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
|
||||||
|
(
|
||||||
|
utts,
|
||||||
|
ref_rms_list,
|
||||||
|
ref_mels,
|
||||||
|
ref_mel_lens,
|
||||||
|
total_mel_lens,
|
||||||
|
final_text_list,
|
||||||
|
) = prompt
|
||||||
|
ref_mels = ref_mels.to(device)
|
||||||
|
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
|
||||||
|
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
with torch.inference_mode():
|
||||||
|
generated, _ = model.sample(
|
||||||
|
cond=ref_mels,
|
||||||
|
text=final_text_list,
|
||||||
|
duration=total_mel_lens,
|
||||||
|
lens=ref_mel_lens,
|
||||||
|
steps=args.nfe,
|
||||||
|
cfg_strength=2.0,
|
||||||
|
sway_sampling_coef=args.swaysampling,
|
||||||
|
no_ref_audio=False,
|
||||||
|
seed=args.seed,
|
||||||
|
)
|
||||||
|
for i, gen in enumerate(generated):
|
||||||
|
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
||||||
|
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
|
||||||
|
|
||||||
|
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
|
||||||
|
target_rms = 0.1
|
||||||
|
target_sample_rate = 24_000
|
||||||
|
if ref_rms_list[i] < target_rms:
|
||||||
|
generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
||||||
|
torchaudio.save(
|
||||||
|
f"{args.output_dir}/{utts[i]}.wav",
|
||||||
|
generated_wave,
|
||||||
|
target_sample_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
timediff = time.time() - start
|
||||||
|
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
3
egs/wenetspeech4tts/TTS/f5-tts/model/README.md
Normal file
3
egs/wenetspeech4tts/TTS/f5-tts/model/README.md
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# Introduction
|
||||||
|
Files in this folder are copied from
|
||||||
|
https://github.com/SWivid/F5-TTS/tree/main/src/f5_tts/model
|
326
egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py
Normal file
326
egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py
Normal file
@ -0,0 +1,326 @@
|
|||||||
|
"""
|
||||||
|
ein notation:
|
||||||
|
b - batch
|
||||||
|
n - sequence
|
||||||
|
nt - text sequence
|
||||||
|
nw - raw wave length
|
||||||
|
d - dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from random import random
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from model.modules import MelSpec
|
||||||
|
from model.utils import (
|
||||||
|
default,
|
||||||
|
exists,
|
||||||
|
lens_to_mask,
|
||||||
|
list_str_to_idx,
|
||||||
|
list_str_to_tensor,
|
||||||
|
mask_from_frac_lengths,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from torchdiffeq import odeint
|
||||||
|
|
||||||
|
|
||||||
|
class CFM(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
transformer: nn.Module,
|
||||||
|
sigma=0.0,
|
||||||
|
odeint_kwargs: dict = dict(
|
||||||
|
# atol = 1e-5,
|
||||||
|
# rtol = 1e-5,
|
||||||
|
method="euler" # 'midpoint'
|
||||||
|
),
|
||||||
|
audio_drop_prob=0.3,
|
||||||
|
cond_drop_prob=0.2,
|
||||||
|
num_channels=None,
|
||||||
|
mel_spec_module: nn.Module | None = None,
|
||||||
|
mel_spec_kwargs: dict = dict(),
|
||||||
|
frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
|
||||||
|
vocab_char_map: dict[str:int] | None = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.frac_lengths_mask = frac_lengths_mask
|
||||||
|
|
||||||
|
# mel spec
|
||||||
|
self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
|
||||||
|
num_channels = default(num_channels, self.mel_spec.n_mel_channels)
|
||||||
|
self.num_channels = num_channels
|
||||||
|
|
||||||
|
# classifier-free guidance
|
||||||
|
self.audio_drop_prob = audio_drop_prob
|
||||||
|
self.cond_drop_prob = cond_drop_prob
|
||||||
|
|
||||||
|
# transformer
|
||||||
|
self.transformer = transformer
|
||||||
|
dim = transformer.dim
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
# conditional flow related
|
||||||
|
self.sigma = sigma
|
||||||
|
|
||||||
|
# sampling related
|
||||||
|
self.odeint_kwargs = odeint_kwargs
|
||||||
|
|
||||||
|
# vocab map for tokenization
|
||||||
|
self.vocab_char_map = vocab_char_map
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
cond: float["b n d"] | float["b nw"], # noqa: F722
|
||||||
|
text: int["b nt"] | list[str], # noqa: F722
|
||||||
|
duration: int | int["b"], # noqa: F821
|
||||||
|
*,
|
||||||
|
lens: int["b"] | None = None, # noqa: F821
|
||||||
|
steps=32,
|
||||||
|
cfg_strength=1.0,
|
||||||
|
sway_sampling_coef=None,
|
||||||
|
seed: int | None = None,
|
||||||
|
max_duration=4096,
|
||||||
|
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
|
||||||
|
no_ref_audio=False,
|
||||||
|
duplicate_test=False,
|
||||||
|
t_inter=0.1,
|
||||||
|
edit_mask=None,
|
||||||
|
):
|
||||||
|
self.eval()
|
||||||
|
# raw wave
|
||||||
|
|
||||||
|
if cond.ndim == 2:
|
||||||
|
cond = self.mel_spec(cond)
|
||||||
|
cond = cond.permute(0, 2, 1)
|
||||||
|
assert cond.shape[-1] == self.num_channels
|
||||||
|
|
||||||
|
cond = cond.to(next(self.parameters()).dtype)
|
||||||
|
|
||||||
|
batch, cond_seq_len, device = *cond.shape[:2], cond.device
|
||||||
|
if not exists(lens):
|
||||||
|
lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
# text
|
||||||
|
|
||||||
|
if isinstance(text, list):
|
||||||
|
if exists(self.vocab_char_map):
|
||||||
|
text = list_str_to_idx(text, self.vocab_char_map).to(device)
|
||||||
|
else:
|
||||||
|
text = list_str_to_tensor(text).to(device)
|
||||||
|
assert text.shape[0] == batch
|
||||||
|
|
||||||
|
if exists(text):
|
||||||
|
text_lens = (text != -1).sum(dim=-1)
|
||||||
|
lens = torch.maximum(
|
||||||
|
text_lens, lens
|
||||||
|
) # make sure lengths are at least those of the text characters
|
||||||
|
|
||||||
|
# duration
|
||||||
|
|
||||||
|
cond_mask = lens_to_mask(lens)
|
||||||
|
if edit_mask is not None:
|
||||||
|
cond_mask = cond_mask & edit_mask
|
||||||
|
|
||||||
|
if isinstance(duration, int):
|
||||||
|
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
duration = torch.maximum(
|
||||||
|
lens + 1, duration
|
||||||
|
) # just add one token so something is generated
|
||||||
|
duration = duration.clamp(max=max_duration)
|
||||||
|
max_duration = duration.amax()
|
||||||
|
|
||||||
|
# duplicate test corner for inner time step oberservation
|
||||||
|
if duplicate_test:
|
||||||
|
test_cond = F.pad(
|
||||||
|
cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
|
||||||
|
cond_mask = F.pad(
|
||||||
|
cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
|
||||||
|
)
|
||||||
|
cond_mask = cond_mask.unsqueeze(-1)
|
||||||
|
step_cond = torch.where(
|
||||||
|
cond_mask, cond, torch.zeros_like(cond)
|
||||||
|
) # allow direct control (cut cond audio) with lens passed in
|
||||||
|
|
||||||
|
if batch > 1:
|
||||||
|
mask = lens_to_mask(duration)
|
||||||
|
else: # save memory and speed up, as single inference need no mask currently
|
||||||
|
mask = None
|
||||||
|
|
||||||
|
# test for no ref audio
|
||||||
|
if no_ref_audio:
|
||||||
|
cond = torch.zeros_like(cond)
|
||||||
|
|
||||||
|
# neural ode
|
||||||
|
|
||||||
|
def fn(t, x):
|
||||||
|
# at each step, conditioning is fixed
|
||||||
|
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
|
||||||
|
|
||||||
|
# predict flow
|
||||||
|
pred = self.transformer(
|
||||||
|
x=x,
|
||||||
|
cond=step_cond,
|
||||||
|
text=text,
|
||||||
|
time=t,
|
||||||
|
mask=mask,
|
||||||
|
drop_audio_cond=False,
|
||||||
|
drop_text=False,
|
||||||
|
)
|
||||||
|
if cfg_strength < 1e-5:
|
||||||
|
return pred
|
||||||
|
|
||||||
|
null_pred = self.transformer(
|
||||||
|
x=x,
|
||||||
|
cond=step_cond,
|
||||||
|
text=text,
|
||||||
|
time=t,
|
||||||
|
mask=mask,
|
||||||
|
drop_audio_cond=True,
|
||||||
|
drop_text=True,
|
||||||
|
)
|
||||||
|
return pred + (pred - null_pred) * cfg_strength
|
||||||
|
|
||||||
|
# noise input
|
||||||
|
# to make sure batch inference result is same with different batch size, and for sure single inference
|
||||||
|
# still some difference maybe due to convolutional layers
|
||||||
|
y0 = []
|
||||||
|
for dur in duration:
|
||||||
|
if exists(seed):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
y0.append(
|
||||||
|
torch.randn(
|
||||||
|
dur, self.num_channels, device=self.device, dtype=step_cond.dtype
|
||||||
|
)
|
||||||
|
)
|
||||||
|
y0 = pad_sequence(y0, padding_value=0, batch_first=True)
|
||||||
|
|
||||||
|
t_start = 0
|
||||||
|
|
||||||
|
# duplicate test corner for inner time step oberservation
|
||||||
|
if duplicate_test:
|
||||||
|
t_start = t_inter
|
||||||
|
y0 = (1 - t_start) * y0 + t_start * test_cond
|
||||||
|
steps = int(steps * (1 - t_start))
|
||||||
|
|
||||||
|
t = torch.linspace(
|
||||||
|
t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype
|
||||||
|
)
|
||||||
|
if sway_sampling_coef is not None:
|
||||||
|
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
|
||||||
|
|
||||||
|
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
|
||||||
|
|
||||||
|
sampled = trajectory[-1]
|
||||||
|
out = sampled
|
||||||
|
out = torch.where(cond_mask, cond, out)
|
||||||
|
|
||||||
|
if exists(vocoder):
|
||||||
|
out = out.permute(0, 2, 1)
|
||||||
|
out = vocoder(out)
|
||||||
|
|
||||||
|
return out, trajectory
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
|
||||||
|
text: int["b nt"] | list[str], # noqa: F722
|
||||||
|
*,
|
||||||
|
lens: int["b"] | None = None, # noqa: F821
|
||||||
|
noise_scheduler: str | None = None,
|
||||||
|
):
|
||||||
|
# handle raw wave
|
||||||
|
if inp.ndim == 2:
|
||||||
|
inp = self.mel_spec(inp)
|
||||||
|
inp = inp.permute(0, 2, 1)
|
||||||
|
assert inp.shape[-1] == self.num_channels
|
||||||
|
|
||||||
|
batch, seq_len, dtype, device, _σ1 = (
|
||||||
|
*inp.shape[:2],
|
||||||
|
inp.dtype,
|
||||||
|
self.device,
|
||||||
|
self.sigma,
|
||||||
|
)
|
||||||
|
|
||||||
|
# handle text as string
|
||||||
|
if isinstance(text, list):
|
||||||
|
if exists(self.vocab_char_map):
|
||||||
|
text = list_str_to_idx(text, self.vocab_char_map).to(device)
|
||||||
|
else:
|
||||||
|
text = list_str_to_tensor(text).to(device)
|
||||||
|
assert text.shape[0] == batch
|
||||||
|
|
||||||
|
# lens and mask
|
||||||
|
if not exists(lens):
|
||||||
|
lens = torch.full((batch,), seq_len, device=device)
|
||||||
|
|
||||||
|
mask = lens_to_mask(
|
||||||
|
lens, length=seq_len
|
||||||
|
) # useless here, as collate_fn will pad to max length in batch
|
||||||
|
|
||||||
|
# get a random span to mask out for training conditionally
|
||||||
|
frac_lengths = (
|
||||||
|
torch.zeros((batch,), device=self.device)
|
||||||
|
.float()
|
||||||
|
.uniform_(*self.frac_lengths_mask)
|
||||||
|
)
|
||||||
|
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
rand_span_mask &= mask
|
||||||
|
|
||||||
|
# mel is x1
|
||||||
|
x1 = inp
|
||||||
|
|
||||||
|
# x0 is gaussian noise
|
||||||
|
x0 = torch.randn_like(x1)
|
||||||
|
|
||||||
|
# time step
|
||||||
|
time = torch.rand((batch,), dtype=dtype, device=self.device)
|
||||||
|
# TODO. noise_scheduler
|
||||||
|
|
||||||
|
# sample xt (φ_t(x) in the paper)
|
||||||
|
t = time.unsqueeze(-1).unsqueeze(-1)
|
||||||
|
φ = (1 - t) * x0 + t * x1
|
||||||
|
flow = x1 - x0
|
||||||
|
|
||||||
|
# only predict what is within the random mask span for infilling
|
||||||
|
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
|
||||||
|
|
||||||
|
# transformer and cfg training with a drop rate
|
||||||
|
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
|
||||||
|
if random() < self.cond_drop_prob: # p_uncond in voicebox paper
|
||||||
|
drop_audio_cond = True
|
||||||
|
drop_text = True
|
||||||
|
else:
|
||||||
|
drop_text = False
|
||||||
|
|
||||||
|
# if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
|
||||||
|
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
|
||||||
|
pred = self.transformer(
|
||||||
|
x=φ,
|
||||||
|
cond=cond,
|
||||||
|
text=text,
|
||||||
|
time=time,
|
||||||
|
drop_audio_cond=drop_audio_cond,
|
||||||
|
drop_text=drop_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
# flow matching loss
|
||||||
|
loss = F.mse_loss(pred, flow, reduction="none")
|
||||||
|
loss = loss[rand_span_mask]
|
||||||
|
|
||||||
|
return loss.mean(), cond, pred
|
210
egs/wenetspeech4tts/TTS/f5-tts/model/dit.py
Normal file
210
egs/wenetspeech4tts/TTS/f5-tts/model/dit.py
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
"""
|
||||||
|
ein notation:
|
||||||
|
b - batch
|
||||||
|
n - sequence
|
||||||
|
nt - text sequence
|
||||||
|
nw - raw wave length
|
||||||
|
d - dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from model.modules import (
|
||||||
|
AdaLayerNormZero_Final,
|
||||||
|
ConvNeXtV2Block,
|
||||||
|
ConvPositionEmbedding,
|
||||||
|
DiTBlock,
|
||||||
|
TimestepEmbedding,
|
||||||
|
get_pos_embed_indices,
|
||||||
|
precompute_freqs_cis,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
from x_transformers.x_transformers import RotaryEmbedding
|
||||||
|
|
||||||
|
# Text embedding
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbedding(nn.Module):
|
||||||
|
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
||||||
|
super().__init__()
|
||||||
|
self.text_embed = nn.Embedding(
|
||||||
|
text_num_embeds + 1, text_dim
|
||||||
|
) # use 0 as filler token
|
||||||
|
|
||||||
|
if conv_layers > 0:
|
||||||
|
self.extra_modeling = True
|
||||||
|
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
||||||
|
self.register_buffer(
|
||||||
|
"freqs_cis",
|
||||||
|
precompute_freqs_cis(text_dim, self.precompute_max_pos),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
self.text_blocks = nn.Sequential(
|
||||||
|
*[
|
||||||
|
ConvNeXtV2Block(text_dim, text_dim * conv_mult)
|
||||||
|
for _ in range(conv_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.extra_modeling = False
|
||||||
|
|
||||||
|
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
||||||
|
text = (
|
||||||
|
text + 1
|
||||||
|
) # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||||
|
text = text[
|
||||||
|
:, :seq_len
|
||||||
|
] # curtail if character tokens are more than the mel spec tokens
|
||||||
|
batch, text_len = text.shape[0], text.shape[1]
|
||||||
|
text = F.pad(text, (0, seq_len - text_len), value=0)
|
||||||
|
|
||||||
|
if drop_text: # cfg for text
|
||||||
|
text = torch.zeros_like(text)
|
||||||
|
|
||||||
|
text = self.text_embed(text) # b n -> b n d
|
||||||
|
|
||||||
|
# possible extra modeling
|
||||||
|
if self.extra_modeling:
|
||||||
|
# sinus pos emb
|
||||||
|
batch_start = torch.zeros((batch,), dtype=torch.long)
|
||||||
|
pos_idx = get_pos_embed_indices(
|
||||||
|
batch_start, seq_len, max_pos=self.precompute_max_pos
|
||||||
|
)
|
||||||
|
text_pos_embed = self.freqs_cis[pos_idx]
|
||||||
|
text = text + text_pos_embed
|
||||||
|
|
||||||
|
# convnextv2 blocks
|
||||||
|
text = self.text_blocks(text)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# noised input audio and context mixing embedding
|
||||||
|
|
||||||
|
|
||||||
|
class InputEmbedding(nn.Module):
|
||||||
|
def __init__(self, mel_dim, text_dim, out_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
||||||
|
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: float["b n d"], # noqa: F722
|
||||||
|
cond: float["b n d"], # noqa: F722
|
||||||
|
text_embed: float["b n d"], # noqa: F722
|
||||||
|
drop_audio_cond=False,
|
||||||
|
):
|
||||||
|
if drop_audio_cond: # cfg for cond audio
|
||||||
|
cond = torch.zeros_like(cond)
|
||||||
|
|
||||||
|
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
|
||||||
|
x = self.conv_pos_embed(x) + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# Transformer backbone using DiT blocks
|
||||||
|
|
||||||
|
|
||||||
|
class DiT(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim,
|
||||||
|
depth=8,
|
||||||
|
heads=8,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.1,
|
||||||
|
ff_mult=4,
|
||||||
|
mel_dim=100,
|
||||||
|
text_num_embeds=256,
|
||||||
|
text_dim=None,
|
||||||
|
conv_layers=0,
|
||||||
|
long_skip_connection=False,
|
||||||
|
checkpoint_activations=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.time_embed = TimestepEmbedding(dim)
|
||||||
|
if text_dim is None:
|
||||||
|
text_dim = mel_dim
|
||||||
|
self.text_embed = TextEmbedding(
|
||||||
|
text_num_embeds, text_dim, conv_layers=conv_layers
|
||||||
|
)
|
||||||
|
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
||||||
|
|
||||||
|
self.rotary_embed = RotaryEmbedding(dim_head)
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DiTBlock(
|
||||||
|
dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
dim_head=dim_head,
|
||||||
|
ff_mult=ff_mult,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
for _ in range(depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.long_skip_connection = (
|
||||||
|
nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||||
|
self.proj_out = nn.Linear(dim, mel_dim)
|
||||||
|
|
||||||
|
self.checkpoint_activations = checkpoint_activations
|
||||||
|
|
||||||
|
def ckpt_wrapper(self, module):
|
||||||
|
# https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
|
||||||
|
def ckpt_forward(*inputs):
|
||||||
|
outputs = module(*inputs)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
return ckpt_forward
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: float["b n d"], # nosied input audio # noqa: F722
|
||||||
|
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||||
|
text: int["b nt"], # text # noqa: F722
|
||||||
|
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||||
|
drop_audio_cond, # cfg for cond audio
|
||||||
|
drop_text, # cfg for text
|
||||||
|
mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
):
|
||||||
|
batch, seq_len = x.shape[0], x.shape[1]
|
||||||
|
if time.ndim == 0:
|
||||||
|
time = time.repeat(batch)
|
||||||
|
|
||||||
|
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||||
|
t = self.time_embed(time)
|
||||||
|
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||||
|
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||||
|
|
||||||
|
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||||
|
|
||||||
|
if self.long_skip_connection is not None:
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
if self.checkpoint_activations:
|
||||||
|
x = torch.utils.checkpoint.checkpoint(
|
||||||
|
self.ckpt_wrapper(block), x, t, mask, rope
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x = block(x, t, mask=mask, rope=rope)
|
||||||
|
|
||||||
|
if self.long_skip_connection is not None:
|
||||||
|
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
|
||||||
|
|
||||||
|
x = self.norm_out(x, t)
|
||||||
|
output = self.proj_out(x)
|
||||||
|
|
||||||
|
return output
|
728
egs/wenetspeech4tts/TTS/f5-tts/model/modules.py
Normal file
728
egs/wenetspeech4tts/TTS/f5-tts/model/modules.py
Normal file
@ -0,0 +1,728 @@
|
|||||||
|
"""
|
||||||
|
ein notation:
|
||||||
|
b - batch
|
||||||
|
n - sequence
|
||||||
|
nt - text sequence
|
||||||
|
nw - raw wave length
|
||||||
|
d - dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
from librosa.filters import mel as librosa_mel_fn
|
||||||
|
from torch import nn
|
||||||
|
from x_transformers.x_transformers import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
# raw wav to mel spec
|
||||||
|
|
||||||
|
|
||||||
|
mel_basis_cache = {}
|
||||||
|
hann_window_cache = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_bigvgan_mel_spectrogram(
|
||||||
|
waveform,
|
||||||
|
n_fft=1024,
|
||||||
|
n_mel_channels=100,
|
||||||
|
target_sample_rate=24000,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
fmin=0,
|
||||||
|
fmax=None,
|
||||||
|
center=False,
|
||||||
|
): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
|
||||||
|
device = waveform.device
|
||||||
|
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
|
||||||
|
|
||||||
|
if key not in mel_basis_cache:
|
||||||
|
mel = librosa_mel_fn(
|
||||||
|
sr=target_sample_rate,
|
||||||
|
n_fft=n_fft,
|
||||||
|
n_mels=n_mel_channels,
|
||||||
|
fmin=fmin,
|
||||||
|
fmax=fmax,
|
||||||
|
)
|
||||||
|
mel_basis_cache[key] = (
|
||||||
|
torch.from_numpy(mel).float().to(device)
|
||||||
|
) # TODO: why they need .float()?
|
||||||
|
hann_window_cache[key] = torch.hann_window(win_length).to(device)
|
||||||
|
|
||||||
|
mel_basis = mel_basis_cache[key]
|
||||||
|
hann_window = hann_window_cache[key]
|
||||||
|
|
||||||
|
padding = (n_fft - hop_length) // 2
|
||||||
|
waveform = torch.nn.functional.pad(
|
||||||
|
waveform.unsqueeze(1), (padding, padding), mode="reflect"
|
||||||
|
).squeeze(1)
|
||||||
|
|
||||||
|
spec = torch.stft(
|
||||||
|
waveform,
|
||||||
|
n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
window=hann_window,
|
||||||
|
center=center,
|
||||||
|
pad_mode="reflect",
|
||||||
|
normalized=False,
|
||||||
|
onesided=True,
|
||||||
|
return_complex=True,
|
||||||
|
)
|
||||||
|
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
||||||
|
|
||||||
|
mel_spec = torch.matmul(mel_basis, spec)
|
||||||
|
mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
|
||||||
|
|
||||||
|
return mel_spec
|
||||||
|
|
||||||
|
|
||||||
|
def get_vocos_mel_spectrogram(
|
||||||
|
waveform,
|
||||||
|
n_fft=1024,
|
||||||
|
n_mel_channels=100,
|
||||||
|
target_sample_rate=24000,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
):
|
||||||
|
mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||||
|
sample_rate=target_sample_rate,
|
||||||
|
n_fft=n_fft,
|
||||||
|
win_length=win_length,
|
||||||
|
hop_length=hop_length,
|
||||||
|
n_mels=n_mel_channels,
|
||||||
|
power=1,
|
||||||
|
center=True,
|
||||||
|
normalized=False,
|
||||||
|
norm=None,
|
||||||
|
).to(waveform.device)
|
||||||
|
if len(waveform.shape) == 3:
|
||||||
|
waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
|
||||||
|
|
||||||
|
assert len(waveform.shape) == 2
|
||||||
|
|
||||||
|
mel = mel_stft(waveform)
|
||||||
|
mel = mel.clamp(min=1e-5).log()
|
||||||
|
return mel
|
||||||
|
|
||||||
|
|
||||||
|
class MelSpec(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_fft=1024,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
n_mel_channels=100,
|
||||||
|
target_sample_rate=24_000,
|
||||||
|
mel_spec_type="vocos",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert mel_spec_type in ["vocos", "bigvgan"], print(
|
||||||
|
"We only support two extract mel backend: vocos or bigvgan"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.n_fft = n_fft
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.win_length = win_length
|
||||||
|
self.n_mel_channels = n_mel_channels
|
||||||
|
self.target_sample_rate = target_sample_rate
|
||||||
|
|
||||||
|
if mel_spec_type == "vocos":
|
||||||
|
self.extractor = get_vocos_mel_spectrogram
|
||||||
|
elif mel_spec_type == "bigvgan":
|
||||||
|
self.extractor = get_bigvgan_mel_spectrogram
|
||||||
|
|
||||||
|
self.register_buffer("dummy", torch.tensor(0), persistent=False)
|
||||||
|
|
||||||
|
def forward(self, wav):
|
||||||
|
if self.dummy.device != wav.device:
|
||||||
|
self.to(wav.device)
|
||||||
|
|
||||||
|
mel = self.extractor(
|
||||||
|
waveform=wav,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
n_mel_channels=self.n_mel_channels,
|
||||||
|
target_sample_rate=self.target_sample_rate,
|
||||||
|
hop_length=self.hop_length,
|
||||||
|
win_length=self.win_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
return mel
|
||||||
|
|
||||||
|
|
||||||
|
# sinusoidal position embedding
|
||||||
|
|
||||||
|
|
||||||
|
class SinusPositionEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, x, scale=1000):
|
||||||
|
device = x.device
|
||||||
|
half_dim = self.dim // 2
|
||||||
|
emb = math.log(10000) / (half_dim - 1)
|
||||||
|
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
||||||
|
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||||
|
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
# convolutional position embedding
|
||||||
|
|
||||||
|
|
||||||
|
class ConvPositionEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, kernel_size=31, groups=16):
|
||||||
|
super().__init__()
|
||||||
|
assert kernel_size % 2 != 0
|
||||||
|
self.conv1d = nn.Sequential(
|
||||||
|
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||||
|
nn.Mish(),
|
||||||
|
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||||
|
nn.Mish(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask[..., None]
|
||||||
|
x = x.masked_fill(~mask, 0.0)
|
||||||
|
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
x = self.conv1d(x)
|
||||||
|
out = x.permute(0, 2, 1)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
out = out.masked_fill(~mask, 0.0)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# rotary positional embedding related
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis(
|
||||||
|
dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0
|
||||||
|
):
|
||||||
|
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||||
|
# has some connection to NTK literature
|
||||||
|
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||||
|
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
|
||||||
|
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||||
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||||
|
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||||
|
freqs = torch.outer(t, freqs).float() # type: ignore
|
||||||
|
freqs_cos = torch.cos(freqs) # real part
|
||||||
|
freqs_sin = torch.sin(freqs) # imaginary part
|
||||||
|
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
||||||
|
# length = length if isinstance(length, int) else length.max()
|
||||||
|
scale = scale * torch.ones_like(
|
||||||
|
start, dtype=torch.float32
|
||||||
|
) # in case scale is a scalar
|
||||||
|
pos = (
|
||||||
|
start.unsqueeze(1)
|
||||||
|
+ (
|
||||||
|
torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0)
|
||||||
|
* scale.unsqueeze(1)
|
||||||
|
).long()
|
||||||
|
)
|
||||||
|
# avoid extra long error.
|
||||||
|
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
# Global Response Normalization layer (Instance Normalization ?)
|
||||||
|
|
||||||
|
|
||||||
|
class GRN(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
|
||||||
|
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
|
||||||
|
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||||
|
return self.gamma * (x * Nx) + self.beta + x
|
||||||
|
|
||||||
|
|
||||||
|
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
|
||||||
|
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
|
||||||
|
|
||||||
|
|
||||||
|
class ConvNeXtV2Block(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
intermediate_dim: int,
|
||||||
|
dilation: int = 1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
padding = (dilation * (7 - 1)) // 2
|
||||||
|
self.dwconv = nn.Conv1d(
|
||||||
|
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
||||||
|
) # depthwise conv
|
||||||
|
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||||
|
self.pwconv1 = nn.Linear(
|
||||||
|
dim, intermediate_dim
|
||||||
|
) # pointwise/1x1 convs, implemented with linear layers
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.grn = GRN(intermediate_dim)
|
||||||
|
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
residual = x
|
||||||
|
x = x.transpose(1, 2) # b n d -> b d n
|
||||||
|
x = self.dwconv(x)
|
||||||
|
x = x.transpose(1, 2) # b d n -> b n d
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.pwconv1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.grn(x)
|
||||||
|
x = self.pwconv2(x)
|
||||||
|
return residual + x
|
||||||
|
|
||||||
|
|
||||||
|
# AdaLayerNormZero
|
||||||
|
# return with modulated x for attn input, and params for later mlp modulation
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNormZero(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.linear = nn.Linear(dim, dim * 6)
|
||||||
|
|
||||||
|
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
def forward(self, x, emb=None):
|
||||||
|
emb = self.linear(self.silu(emb))
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(
|
||||||
|
emb, 6, dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||||
|
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||||
|
|
||||||
|
|
||||||
|
# AdaLayerNormZero for final layer
|
||||||
|
# return only with modulated x for attn input, cuz no more mlp modulation
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNormZero_Final(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.linear = nn.Linear(dim, dim * 2)
|
||||||
|
|
||||||
|
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
def forward(self, x, emb):
|
||||||
|
emb = self.linear(self.silu(emb))
|
||||||
|
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||||
|
|
||||||
|
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# FeedForward
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
dim_out = dim_out if dim_out is not None else dim
|
||||||
|
|
||||||
|
activation = nn.GELU(approximate=approximate)
|
||||||
|
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
|
||||||
|
self.ff = nn.Sequential(
|
||||||
|
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.ff(x)
|
||||||
|
|
||||||
|
|
||||||
|
# Attention with possible joint part
|
||||||
|
# modified from diffusers/src/diffusers/models/attention_processor.py
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
processor: JointAttnProcessor | AttnProcessor,
|
||||||
|
dim: int,
|
||||||
|
heads: int = 8,
|
||||||
|
dim_head: int = 64,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
context_dim: Optional[int] = None, # if not None -> joint attention
|
||||||
|
context_pre_only=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
|
raise ImportError(
|
||||||
|
"Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.processor = processor
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.heads = heads
|
||||||
|
self.inner_dim = dim_head * heads
|
||||||
|
self.dropout = dropout
|
||||||
|
|
||||||
|
self.context_dim = context_dim
|
||||||
|
self.context_pre_only = context_pre_only
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(dim, self.inner_dim)
|
||||||
|
self.to_k = nn.Linear(dim, self.inner_dim)
|
||||||
|
self.to_v = nn.Linear(dim, self.inner_dim)
|
||||||
|
|
||||||
|
if self.context_dim is not None:
|
||||||
|
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
|
||||||
|
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
|
||||||
|
if self.context_pre_only is not None:
|
||||||
|
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
|
||||||
|
|
||||||
|
self.to_out = nn.ModuleList([])
|
||||||
|
self.to_out.append(nn.Linear(self.inner_dim, dim))
|
||||||
|
self.to_out.append(nn.Dropout(dropout))
|
||||||
|
|
||||||
|
if self.context_pre_only is not None and not self.context_pre_only:
|
||||||
|
self.to_out_c = nn.Linear(self.inner_dim, dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: float["b n d"], # noised input x # noqa: F722
|
||||||
|
c: float["b n d"] = None, # context c # noqa: F722
|
||||||
|
mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
rope=None, # rotary position embedding for x
|
||||||
|
c_rope=None, # rotary position embedding for c
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if c is not None:
|
||||||
|
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
|
||||||
|
else:
|
||||||
|
return self.processor(self, x, mask=mask, rope=rope)
|
||||||
|
|
||||||
|
|
||||||
|
# Attention processor
|
||||||
|
|
||||||
|
|
||||||
|
class AttnProcessor:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: Attention,
|
||||||
|
x: float["b n d"], # noised input x # noqa: F722
|
||||||
|
mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
rope=None, # rotary position embedding
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
|
# `sample` projections.
|
||||||
|
query = attn.to_q(x)
|
||||||
|
key = attn.to_k(x)
|
||||||
|
value = attn.to_v(x)
|
||||||
|
|
||||||
|
# apply rotary position embedding
|
||||||
|
if rope is not None:
|
||||||
|
freqs, xpos_scale = rope
|
||||||
|
q_xpos_scale, k_xpos_scale = (
|
||||||
|
(xpos_scale, xpos_scale**-1.0)
|
||||||
|
if xpos_scale is not None
|
||||||
|
else (1.0, 1.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||||
|
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||||
|
if mask is not None:
|
||||||
|
attn_mask = mask
|
||||||
|
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||||
|
attn_mask = attn_mask.expand(
|
||||||
|
batch_size, attn.heads, query.shape[-2], key.shape[-2]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_mask = None
|
||||||
|
|
||||||
|
x = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
x = x.to(query.dtype)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
x = attn.to_out[0](x)
|
||||||
|
# dropout
|
||||||
|
x = attn.to_out[1](x)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
x = x.masked_fill(~mask, 0.0)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# Joint Attention processor for MM-DiT
|
||||||
|
# modified from diffusers/src/diffusers/models/attention_processor.py
|
||||||
|
|
||||||
|
|
||||||
|
class JointAttnProcessor:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: Attention,
|
||||||
|
x: float["b n d"], # noised input x # noqa: F722
|
||||||
|
c: float["b nt d"] = None, # context c, here text # noqa: F722
|
||||||
|
mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
rope=None, # rotary position embedding for x
|
||||||
|
c_rope=None, # rotary position embedding for c
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
batch_size = c.shape[0]
|
||||||
|
|
||||||
|
# `sample` projections.
|
||||||
|
query = attn.to_q(x)
|
||||||
|
key = attn.to_k(x)
|
||||||
|
value = attn.to_v(x)
|
||||||
|
|
||||||
|
# `context` projections.
|
||||||
|
c_query = attn.to_q_c(c)
|
||||||
|
c_key = attn.to_k_c(c)
|
||||||
|
c_value = attn.to_v_c(c)
|
||||||
|
|
||||||
|
# apply rope for context and noised input independently
|
||||||
|
if rope is not None:
|
||||||
|
freqs, xpos_scale = rope
|
||||||
|
q_xpos_scale, k_xpos_scale = (
|
||||||
|
(xpos_scale, xpos_scale**-1.0)
|
||||||
|
if xpos_scale is not None
|
||||||
|
else (1.0, 1.0)
|
||||||
|
)
|
||||||
|
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||||
|
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||||
|
if c_rope is not None:
|
||||||
|
freqs, xpos_scale = c_rope
|
||||||
|
q_xpos_scale, k_xpos_scale = (
|
||||||
|
(xpos_scale, xpos_scale**-1.0)
|
||||||
|
if xpos_scale is not None
|
||||||
|
else (1.0, 1.0)
|
||||||
|
)
|
||||||
|
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
||||||
|
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
query = torch.cat([query, c_query], dim=1)
|
||||||
|
key = torch.cat([key, c_key], dim=1)
|
||||||
|
value = torch.cat([value, c_value], dim=1)
|
||||||
|
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||||
|
if mask is not None:
|
||||||
|
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
|
||||||
|
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||||
|
attn_mask = attn_mask.expand(
|
||||||
|
batch_size, attn.heads, query.shape[-2], key.shape[-2]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_mask = None
|
||||||
|
|
||||||
|
x = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
x = x.to(query.dtype)
|
||||||
|
|
||||||
|
# Split the attention outputs.
|
||||||
|
x, c = (
|
||||||
|
x[:, : residual.shape[1]],
|
||||||
|
x[:, residual.shape[1] :],
|
||||||
|
)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
x = attn.to_out[0](x)
|
||||||
|
# dropout
|
||||||
|
x = attn.to_out[1](x)
|
||||||
|
if not attn.context_pre_only:
|
||||||
|
c = attn.to_out_c(c)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
x = x.masked_fill(~mask, 0.0)
|
||||||
|
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
|
||||||
|
|
||||||
|
return x, c
|
||||||
|
|
||||||
|
|
||||||
|
# DiT Block
|
||||||
|
|
||||||
|
|
||||||
|
class DiTBlock(nn.Module):
|
||||||
|
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn_norm = AdaLayerNormZero(dim)
|
||||||
|
self.attn = Attention(
|
||||||
|
processor=AttnProcessor(),
|
||||||
|
dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
dim_head=dim_head,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.ff = FeedForward(
|
||||||
|
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
|
||||||
|
# pre-norm & modulation for attention input
|
||||||
|
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
attn_output = self.attn(x=norm, mask=mask, rope=rope)
|
||||||
|
|
||||||
|
# process attention output for input x
|
||||||
|
x = x + gate_msa.unsqueeze(1) * attn_output
|
||||||
|
|
||||||
|
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||||
|
ff_output = self.ff(norm)
|
||||||
|
x = x + gate_mlp.unsqueeze(1) * ff_output
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# MMDiT Block https://arxiv.org/abs/2403.03206
|
||||||
|
|
||||||
|
|
||||||
|
class MMDiTBlock(nn.Module):
|
||||||
|
r"""
|
||||||
|
modified from diffusers/src/diffusers/models/attention.py
|
||||||
|
|
||||||
|
notes.
|
||||||
|
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
|
||||||
|
_x: noised input related. (right part)
|
||||||
|
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.context_pre_only = context_pre_only
|
||||||
|
|
||||||
|
self.attn_norm_c = (
|
||||||
|
AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
|
||||||
|
)
|
||||||
|
self.attn_norm_x = AdaLayerNormZero(dim)
|
||||||
|
self.attn = Attention(
|
||||||
|
processor=JointAttnProcessor(),
|
||||||
|
dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
dim_head=dim_head,
|
||||||
|
dropout=dropout,
|
||||||
|
context_dim=dim,
|
||||||
|
context_pre_only=context_pre_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context_pre_only:
|
||||||
|
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.ff_c = FeedForward(
|
||||||
|
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.ff_norm_c = None
|
||||||
|
self.ff_c = None
|
||||||
|
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.ff_x = FeedForward(
|
||||||
|
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x, c, t, mask=None, rope=None, c_rope=None
|
||||||
|
): # x: noised input, c: context, t: time embedding
|
||||||
|
# pre-norm & modulation for attention input
|
||||||
|
if self.context_pre_only:
|
||||||
|
norm_c = self.attn_norm_c(c, t)
|
||||||
|
else:
|
||||||
|
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(
|
||||||
|
c, emb=t
|
||||||
|
)
|
||||||
|
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(
|
||||||
|
x, emb=t
|
||||||
|
)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
x_attn_output, c_attn_output = self.attn(
|
||||||
|
x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope
|
||||||
|
)
|
||||||
|
|
||||||
|
# process attention output for context c
|
||||||
|
if self.context_pre_only:
|
||||||
|
c = None
|
||||||
|
else: # if not last layer
|
||||||
|
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
||||||
|
|
||||||
|
norm_c = (
|
||||||
|
self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||||
|
)
|
||||||
|
c_ff_output = self.ff_c(norm_c)
|
||||||
|
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
|
||||||
|
|
||||||
|
# process attention output for input x
|
||||||
|
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
|
||||||
|
|
||||||
|
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
|
||||||
|
x_ff_output = self.ff_x(norm_x)
|
||||||
|
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
|
||||||
|
|
||||||
|
return c, x
|
||||||
|
|
||||||
|
|
||||||
|
# time step conditioning embedding
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, freq_embed_dim=256):
|
||||||
|
super().__init__()
|
||||||
|
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
||||||
|
self.time_mlp = nn.Sequential(
|
||||||
|
nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, timestep: float["b"]): # noqa: F821
|
||||||
|
time_hidden = self.time_embed(timestep)
|
||||||
|
time_hidden = time_hidden.to(timestep.dtype)
|
||||||
|
time = self.time_mlp(time_hidden) # b d
|
||||||
|
return time
|
206
egs/wenetspeech4tts/TTS/f5-tts/model/utils.py
Normal file
206
egs/wenetspeech4tts/TTS/f5-tts/model/utils.py
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from importlib.resources import files
|
||||||
|
|
||||||
|
import jieba
|
||||||
|
import torch
|
||||||
|
from pypinyin import Style, lazy_pinyin
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
# seed everything
|
||||||
|
|
||||||
|
|
||||||
|
def seed_everything(seed=0):
|
||||||
|
random.seed(seed)
|
||||||
|
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
|
||||||
|
# helpers
|
||||||
|
|
||||||
|
|
||||||
|
def exists(v):
|
||||||
|
return v is not None
|
||||||
|
|
||||||
|
|
||||||
|
def default(v, d):
|
||||||
|
return v if exists(v) else d
|
||||||
|
|
||||||
|
|
||||||
|
# tensor helpers
|
||||||
|
|
||||||
|
|
||||||
|
def lens_to_mask(
|
||||||
|
t: int["b"], length: int | None = None # noqa: F722 F821
|
||||||
|
) -> bool["b n"]: # noqa: F722 F821
|
||||||
|
if not exists(length):
|
||||||
|
length = t.amax()
|
||||||
|
|
||||||
|
seq = torch.arange(length, device=t.device)
|
||||||
|
return seq[None, :] < t[:, None]
|
||||||
|
|
||||||
|
|
||||||
|
def mask_from_start_end_indices(
|
||||||
|
seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821
|
||||||
|
):
|
||||||
|
max_seq_len = seq_len.max().item()
|
||||||
|
seq = torch.arange(max_seq_len, device=start.device).long()
|
||||||
|
start_mask = seq[None, :] >= start[:, None]
|
||||||
|
end_mask = seq[None, :] < end[:, None]
|
||||||
|
return start_mask & end_mask
|
||||||
|
|
||||||
|
|
||||||
|
def mask_from_frac_lengths(
|
||||||
|
seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821
|
||||||
|
):
|
||||||
|
lengths = (frac_lengths * seq_len).long()
|
||||||
|
max_start = seq_len - lengths
|
||||||
|
|
||||||
|
rand = torch.rand_like(frac_lengths)
|
||||||
|
start = (max_start * rand).long().clamp(min=0)
|
||||||
|
end = start + lengths
|
||||||
|
|
||||||
|
return mask_from_start_end_indices(seq_len, start, end)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_masked_mean(
|
||||||
|
t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821
|
||||||
|
) -> float["b d"]: # noqa: F722 F821
|
||||||
|
if not exists(mask):
|
||||||
|
return t.mean(dim=1)
|
||||||
|
|
||||||
|
t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
|
||||||
|
num = t.sum(dim=1)
|
||||||
|
den = mask.float().sum(dim=1)
|
||||||
|
|
||||||
|
return num / den.clamp(min=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
# simple utf-8 tokenizer, since paper went character based
|
||||||
|
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
|
||||||
|
list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
|
||||||
|
text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# char tokenizer, based on custom dataset's extracted .txt file
|
||||||
|
def list_str_to_idx(
|
||||||
|
text: list[str] | list[list[str]],
|
||||||
|
vocab_char_map: dict[str, int], # {char: idx}
|
||||||
|
padding_value=-1,
|
||||||
|
) -> int["b nt"]: # noqa: F722
|
||||||
|
list_idx_tensors = [
|
||||||
|
torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text
|
||||||
|
] # pinyin or char style
|
||||||
|
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# Get tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
||||||
|
"""
|
||||||
|
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
|
||||||
|
- "char" for char-wise tokenizer, need .txt vocab_file
|
||||||
|
- "byte" for utf-8 tokenizer
|
||||||
|
- "custom" if you're directly passing in a path to the vocab.txt you want to use
|
||||||
|
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
|
||||||
|
- if use "char", derived from unfiltered character & symbol counts of custom dataset
|
||||||
|
- if use "byte", set to 256 (unicode byte range)
|
||||||
|
"""
|
||||||
|
if tokenizer in ["pinyin", "char"]:
|
||||||
|
tokenizer_path = os.path.join(
|
||||||
|
files("f5_tts").joinpath("../../data"),
|
||||||
|
f"{dataset_name}_{tokenizer}/vocab.txt",
|
||||||
|
)
|
||||||
|
with open(tokenizer_path, "r", encoding="utf-8") as f:
|
||||||
|
vocab_char_map = {}
|
||||||
|
for i, char in enumerate(f):
|
||||||
|
vocab_char_map[char[:-1]] = i
|
||||||
|
vocab_size = len(vocab_char_map)
|
||||||
|
assert (
|
||||||
|
vocab_char_map[" "] == 0
|
||||||
|
), "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
|
||||||
|
|
||||||
|
elif tokenizer == "byte":
|
||||||
|
vocab_char_map = None
|
||||||
|
vocab_size = 256
|
||||||
|
|
||||||
|
elif tokenizer == "custom":
|
||||||
|
with open(dataset_name, "r", encoding="utf-8") as f:
|
||||||
|
vocab_char_map = {}
|
||||||
|
for i, char in enumerate(f):
|
||||||
|
vocab_char_map[char[:-1]] = i
|
||||||
|
vocab_size = len(vocab_char_map)
|
||||||
|
|
||||||
|
return vocab_char_map, vocab_size
|
||||||
|
|
||||||
|
|
||||||
|
# convert char to pinyin
|
||||||
|
|
||||||
|
jieba.initialize()
|
||||||
|
print("Word segmentation module jieba initialized.\n")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_char_to_pinyin(text_list, polyphone=True):
|
||||||
|
final_text_list = []
|
||||||
|
custom_trans = str.maketrans(
|
||||||
|
{";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
|
||||||
|
) # add custom trans here, to address oov
|
||||||
|
|
||||||
|
def is_chinese(c):
|
||||||
|
return "\u3100" <= c <= "\u9fff" # common chinese characters
|
||||||
|
|
||||||
|
for text in text_list:
|
||||||
|
char_list = []
|
||||||
|
text = text.translate(custom_trans)
|
||||||
|
for seg in jieba.cut(text):
|
||||||
|
seg_byte_len = len(bytes(seg, "UTF-8"))
|
||||||
|
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
||||||
|
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
||||||
|
char_list.append(" ")
|
||||||
|
char_list.extend(seg)
|
||||||
|
elif polyphone and seg_byte_len == 3 * len(
|
||||||
|
seg
|
||||||
|
): # if pure east asian characters
|
||||||
|
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
|
||||||
|
for i, c in enumerate(seg):
|
||||||
|
if is_chinese(c):
|
||||||
|
char_list.append(" ")
|
||||||
|
char_list.append(seg_[i])
|
||||||
|
else: # if mixed characters, alphabets and symbols
|
||||||
|
for c in seg:
|
||||||
|
if ord(c) < 256:
|
||||||
|
char_list.extend(c)
|
||||||
|
elif is_chinese(c):
|
||||||
|
char_list.append(" ")
|
||||||
|
char_list.extend(
|
||||||
|
lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
char_list.append(c)
|
||||||
|
final_text_list.append(char_list)
|
||||||
|
|
||||||
|
return final_text_list
|
||||||
|
|
||||||
|
|
||||||
|
# filter func for dirty data with many repetitions
|
||||||
|
|
||||||
|
|
||||||
|
def repetition_found(text, length=2, tolerance=10):
|
||||||
|
pattern_count = defaultdict(int)
|
||||||
|
for i in range(len(text) - length + 1):
|
||||||
|
pattern = text[i : i + length]
|
||||||
|
pattern_count[pattern] += 1
|
||||||
|
for pattern, count in pattern_count.items():
|
||||||
|
if count > tolerance:
|
||||||
|
return True
|
||||||
|
return False
|
104
egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py
Normal file
104
egs/wenetspeech4tts/TTS/f5-tts/speech_synthesis.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
from typing import Callable, Dict, List, Sequence, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lhotse import validate
|
||||||
|
from lhotse.cut import CutSet
|
||||||
|
from lhotse.dataset.collation import collate_audio
|
||||||
|
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
|
||||||
|
from lhotse.utils import ifnone
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechSynthesisDataset(torch.utils.data.Dataset):
|
||||||
|
"""
|
||||||
|
The PyTorch Dataset for the speech synthesis task.
|
||||||
|
Each item in this dataset is a dict of:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
{
|
||||||
|
'audio': (B x NumSamples) float tensor
|
||||||
|
'features': (B x NumFrames x NumFeatures) float tensor
|
||||||
|
'audio_lens': (B, ) int tensor
|
||||||
|
'features_lens': (B, ) int tensor
|
||||||
|
'text': List[str] of len B # when return_text=True
|
||||||
|
'tokens': List[List[str]] # when return_tokens=True
|
||||||
|
'speakers': List[str] of len B # when return_spk_ids=True
|
||||||
|
'cut': List of Cuts # when return_cuts=True
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
|
||||||
|
feature_input_strategy: BatchIO = PrecomputedFeatures(),
|
||||||
|
feature_transforms: Union[Sequence[Callable], Callable] = None,
|
||||||
|
return_text: bool = True,
|
||||||
|
return_tokens: bool = False,
|
||||||
|
return_spk_ids: bool = False,
|
||||||
|
return_cuts: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cut_transforms = ifnone(cut_transforms, [])
|
||||||
|
self.feature_input_strategy = feature_input_strategy
|
||||||
|
|
||||||
|
self.return_text = return_text
|
||||||
|
self.return_tokens = return_tokens
|
||||||
|
self.return_spk_ids = return_spk_ids
|
||||||
|
self.return_cuts = return_cuts
|
||||||
|
|
||||||
|
if feature_transforms is None:
|
||||||
|
feature_transforms = []
|
||||||
|
elif not isinstance(feature_transforms, Sequence):
|
||||||
|
feature_transforms = [feature_transforms]
|
||||||
|
|
||||||
|
assert all(
|
||||||
|
isinstance(transform, Callable) for transform in feature_transforms
|
||||||
|
), "Feature transforms must be Callable"
|
||||||
|
self.feature_transforms = feature_transforms
|
||||||
|
|
||||||
|
def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
|
||||||
|
validate_for_tts(cuts)
|
||||||
|
|
||||||
|
for transform in self.cut_transforms:
|
||||||
|
cuts = transform(cuts)
|
||||||
|
|
||||||
|
# audio, audio_lens = collate_audio(cuts)
|
||||||
|
features, features_lens = self.feature_input_strategy(cuts)
|
||||||
|
|
||||||
|
for transform in self.feature_transforms:
|
||||||
|
features = transform(features)
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
# "audio": audio,
|
||||||
|
"features": features,
|
||||||
|
# "audio_lens": audio_lens,
|
||||||
|
"features_lens": features_lens,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.return_text:
|
||||||
|
# use normalized text
|
||||||
|
# text = [cut.supervisions[0].normalized_text for cut in cuts]
|
||||||
|
text = [cut.supervisions[0].text for cut in cuts]
|
||||||
|
batch["text"] = text
|
||||||
|
|
||||||
|
if self.return_tokens:
|
||||||
|
# tokens = [cut.tokens for cut in cuts]
|
||||||
|
tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts]
|
||||||
|
batch["tokens"] = tokens
|
||||||
|
|
||||||
|
if self.return_spk_ids:
|
||||||
|
batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
|
||||||
|
|
||||||
|
if self.return_cuts:
|
||||||
|
batch["cut"] = [cut for cut in cuts]
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def validate_for_tts(cuts: CutSet) -> None:
|
||||||
|
validate(cuts)
|
||||||
|
for cut in cuts:
|
||||||
|
assert (
|
||||||
|
len(cut.supervisions) == 1
|
||||||
|
), "Only the Cuts with single supervision are supported."
|
1178
egs/wenetspeech4tts/TTS/f5-tts/train.py
Executable file
1178
egs/wenetspeech4tts/TTS/f5-tts/train.py
Executable file
File diff suppressed because it is too large
Load Diff
306
egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py
Normal file
306
egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py
Normal file
@ -0,0 +1,306 @@
|
|||||||
|
# Copyright 2021 Piotr Żelasko
|
||||||
|
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lhotse import CutSet, load_manifest_lazy
|
||||||
|
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset,
|
||||||
|
CutConcatenate,
|
||||||
|
CutMix,
|
||||||
|
DynamicBucketingSampler,
|
||||||
|
PrecomputedFeatures,
|
||||||
|
SimpleCutSampler,
|
||||||
|
)
|
||||||
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
|
AudioSamples,
|
||||||
|
OnTheFlyFeatures,
|
||||||
|
)
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
|
from speech_synthesis import SpeechSynthesisDataset # noqa F401
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
class _SeedWorkers:
|
||||||
|
def __init__(self, seed: int):
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def __call__(self, worker_id: int):
|
||||||
|
fix_random_seed(self.seed + worker_id)
|
||||||
|
|
||||||
|
|
||||||
|
class TtsDataModule:
|
||||||
|
"""
|
||||||
|
DataModule for tts experiments.
|
||||||
|
It assumes there is always one train and valid dataloader,
|
||||||
|
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||||
|
and test-other).
|
||||||
|
|
||||||
|
It contains all the common data pipeline modules used in ASR
|
||||||
|
experiments, e.g.:
|
||||||
|
- dynamic batch size,
|
||||||
|
- bucketing samplers,
|
||||||
|
- cut concatenation,
|
||||||
|
- on-the-fly feature extraction
|
||||||
|
|
||||||
|
This class should be derived for specific corpora used in ASR tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args: argparse.Namespace):
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group(
|
||||||
|
title="TTS data related options",
|
||||||
|
description="These options are used for the preparation of "
|
||||||
|
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||||
|
"effective batch sizes, sampling strategies, applied data "
|
||||||
|
"augmentations, etc.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--manifest-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("data/fbank"),
|
||||||
|
help="Path to directory with train/valid/test cuts.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--max-duration",
|
||||||
|
type=int,
|
||||||
|
default=200.0,
|
||||||
|
help="Maximum pooled recordings duration (seconds) in a "
|
||||||
|
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--bucketing-sampler",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, the batches will come from buckets of "
|
||||||
|
"similar duration (saves padding frames).",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-buckets",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="The number of buckets for the DynamicBucketingSampler"
|
||||||
|
"(you might want to increase it for larger datasets).",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--on-the-fly-feats",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="When enabled, use on-the-fly cut mixing and feature "
|
||||||
|
"extraction. Will drop existing precomputed feature manifests "
|
||||||
|
"if available.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--shuffle",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled (=default), the examples will be "
|
||||||
|
"shuffled for each epoch.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--drop-last",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to drop last batch. Used by sampler.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--return-cuts",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="When enabled, each batch will have the "
|
||||||
|
"field: batch['cut'] with the cuts that "
|
||||||
|
"were used to construct it.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The number of training dataloader workers that "
|
||||||
|
"collect the batches.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--input-strategy",
|
||||||
|
type=str,
|
||||||
|
default="PrecomputedFeatures",
|
||||||
|
help="AudioSamples or PrecomputedFeatures",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefix",
|
||||||
|
type=str,
|
||||||
|
default="wenetspeech4tts",
|
||||||
|
help="prefix of the manifest file",
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_dataloaders(
|
||||||
|
self,
|
||||||
|
cuts_train: CutSet,
|
||||||
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> DataLoader:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
cuts_train:
|
||||||
|
CutSet for training.
|
||||||
|
sampler_state_dict:
|
||||||
|
The state dict for the training sampler.
|
||||||
|
"""
|
||||||
|
logging.info("About to create train dataset")
|
||||||
|
train = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=False,
|
||||||
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"On-the-fly feature extraction is not implemented yet."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.bucketing_sampler:
|
||||||
|
logging.info("Using DynamicBucketingSampler.")
|
||||||
|
train_sampler = DynamicBucketingSampler(
|
||||||
|
cuts_train,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=self.args.shuffle,
|
||||||
|
num_buckets=self.args.num_buckets,
|
||||||
|
buffer_size=self.args.num_buckets * 2000,
|
||||||
|
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||||
|
drop_last=self.args.drop_last,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Using SimpleCutSampler.")
|
||||||
|
train_sampler = SimpleCutSampler(
|
||||||
|
cuts_train,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=self.args.shuffle,
|
||||||
|
)
|
||||||
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
|
if sampler_state_dict is not None:
|
||||||
|
logging.info("Loading sampler state dict")
|
||||||
|
train_sampler.load_state_dict(sampler_state_dict)
|
||||||
|
|
||||||
|
# 'seed' is derived from the current random state, which will have
|
||||||
|
# previously been set in the main process.
|
||||||
|
seed = torch.randint(0, 100000, ()).item()
|
||||||
|
worker_init_fn = _SeedWorkers(seed)
|
||||||
|
|
||||||
|
train_dl = DataLoader(
|
||||||
|
train,
|
||||||
|
sampler=train_sampler,
|
||||||
|
batch_size=None,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
|
persistent_workers=True,
|
||||||
|
pin_memory=True,
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_dl
|
||||||
|
|
||||||
|
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||||
|
logging.info("About to create dev dataset")
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"On-the-fly feature extraction is not implemented yet."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
validate = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=False,
|
||||||
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
valid_sampler = DynamicBucketingSampler(
|
||||||
|
cuts_valid,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
num_buckets=self.args.num_buckets,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
logging.info("About to create valid dataloader")
|
||||||
|
valid_dl = DataLoader(
|
||||||
|
validate,
|
||||||
|
sampler=valid_sampler,
|
||||||
|
batch_size=None,
|
||||||
|
num_workers=2,
|
||||||
|
persistent_workers=True,
|
||||||
|
pin_memory=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return valid_dl
|
||||||
|
|
||||||
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
|
logging.info("About to create test dataset")
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"On-the-fly feature extraction is not implemented yet."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
test = SpeechSynthesisDataset(
|
||||||
|
return_text=True,
|
||||||
|
return_tokens=False,
|
||||||
|
feature_input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
)
|
||||||
|
test_sampler = DynamicBucketingSampler(
|
||||||
|
cuts,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
num_buckets=self.args.num_buckets,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
logging.info("About to create test dataloader")
|
||||||
|
test_dl = DataLoader(
|
||||||
|
test,
|
||||||
|
batch_size=None,
|
||||||
|
sampler=test_sampler,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
|
)
|
||||||
|
return test_dl
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get train cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / f"{self.args.prefix}_cuts_train.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def valid_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get validation cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / f"{self.args.prefix}_cuts_valid.jsonl.gz"
|
||||||
|
)
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def test_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get test cuts")
|
||||||
|
return load_manifest_lazy(
|
||||||
|
self.args.manifest_dir / f"{self.args.prefix}_cuts_test.jsonl.gz"
|
||||||
|
)
|
1
egs/wenetspeech4tts/TTS/f5-tts/utils.py
Symbolic link
1
egs/wenetspeech4tts/TTS/f5-tts/utils.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../ljspeech/TTS/matcha/utils.py
|
2545
egs/wenetspeech4tts/TTS/f5-tts/vocab.txt
Normal file
2545
egs/wenetspeech4tts/TTS/f5-tts/vocab.txt
Normal file
File diff suppressed because it is too large
Load Diff
122
egs/wenetspeech4tts/TTS/local/audio.py
Normal file
122
egs/wenetspeech4tts/TTS/local/audio.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import random
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
from librosa.filters import mel as librosa_mel_fn
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# from env import AttrDict
|
||||||
|
|
||||||
|
MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||||
|
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_range_decompression(x, C=1):
|
||||||
|
return np.exp(x) / C
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||||
|
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_range_decompression_torch(x, C=1):
|
||||||
|
return torch.exp(x) / C
|
||||||
|
|
||||||
|
|
||||||
|
def spectral_normalize_torch(magnitudes):
|
||||||
|
return dynamic_range_compression_torch(magnitudes)
|
||||||
|
|
||||||
|
|
||||||
|
def spectral_de_normalize_torch(magnitudes):
|
||||||
|
return dynamic_range_decompression_torch(magnitudes)
|
||||||
|
|
||||||
|
|
||||||
|
mel_basis_cache = {}
|
||||||
|
hann_window_cache = {}
|
||||||
|
|
||||||
|
|
||||||
|
def mel_spectrogram(
|
||||||
|
y: torch.Tensor,
|
||||||
|
n_fft: int = 1024,
|
||||||
|
num_mels: int = 100,
|
||||||
|
sampling_rate: int = 24_000,
|
||||||
|
hop_size: int = 256,
|
||||||
|
win_size: int = 1024,
|
||||||
|
fmin: int = 0,
|
||||||
|
fmax: int = None,
|
||||||
|
center: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Calculate the mel spectrogram of an input signal.
|
||||||
|
This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): Input signal.
|
||||||
|
n_fft (int): FFT size.
|
||||||
|
num_mels (int): Number of mel bins.
|
||||||
|
sampling_rate (int): Sampling rate of the input signal.
|
||||||
|
hop_size (int): Hop size for STFT.
|
||||||
|
win_size (int): Window size for STFT.
|
||||||
|
fmin (int): Minimum frequency for mel filterbank.
|
||||||
|
fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
|
||||||
|
center (bool): Whether to pad the input to center the frames. Default is False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Mel spectrogram.
|
||||||
|
"""
|
||||||
|
if torch.min(y) < -1.0:
|
||||||
|
print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
|
||||||
|
if torch.max(y) > 1.0:
|
||||||
|
print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
|
||||||
|
|
||||||
|
device = y.device
|
||||||
|
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
|
||||||
|
|
||||||
|
if key not in mel_basis_cache:
|
||||||
|
mel = librosa_mel_fn(
|
||||||
|
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||||
|
)
|
||||||
|
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
|
||||||
|
hann_window_cache[key] = torch.hann_window(win_size).to(device)
|
||||||
|
|
||||||
|
mel_basis = mel_basis_cache[key]
|
||||||
|
hann_window = hann_window_cache[key]
|
||||||
|
|
||||||
|
padding = (n_fft - hop_size) // 2
|
||||||
|
y = torch.nn.functional.pad(
|
||||||
|
y.unsqueeze(1), (padding, padding), mode="reflect"
|
||||||
|
).squeeze(1)
|
||||||
|
|
||||||
|
spec = torch.stft(
|
||||||
|
y,
|
||||||
|
n_fft,
|
||||||
|
hop_length=hop_size,
|
||||||
|
win_length=win_size,
|
||||||
|
window=hann_window,
|
||||||
|
center=center,
|
||||||
|
pad_mode="reflect",
|
||||||
|
normalized=False,
|
||||||
|
onesided=True,
|
||||||
|
return_complex=True,
|
||||||
|
)
|
||||||
|
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
||||||
|
|
||||||
|
mel_spec = torch.matmul(mel_basis, spec)
|
||||||
|
mel_spec = spectral_normalize_torch(mel_spec)
|
||||||
|
|
||||||
|
return mel_spec
|
218
egs/wenetspeech4tts/TTS/local/compute_mel_feat.py
Executable file
218
egs/wenetspeech4tts/TTS/local/compute_mel_feat.py
Executable file
@ -0,0 +1,218 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file computes fbank features of the LJSpeech dataset.
|
||||||
|
It looks for manifests in the directory data/manifests.
|
||||||
|
|
||||||
|
The generated fbank features are saved in data/fbank.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from fbank import MatchaFbank, MatchaFbankConfig
|
||||||
|
from lhotse import CutSet, LilcomChunkyWriter
|
||||||
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
|
from icefall.utils import get_executor
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-jobs",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--src-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("data/manifests"),
|
||||||
|
help="Path to the manifest files",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("data/fbank"),
|
||||||
|
help="Path to the tokenized files",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-parts",
|
||||||
|
type=str,
|
||||||
|
default="Basic",
|
||||||
|
help="Space separated dataset parts",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefix",
|
||||||
|
type=str,
|
||||||
|
default="wenetspeech4tts",
|
||||||
|
help="prefix of the manifest file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--suffix",
|
||||||
|
type=str,
|
||||||
|
default="jsonl.gz",
|
||||||
|
help="suffix of the manifest file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--split",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Split the cut_set into multiple parts",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--resample-to-24kHz",
|
||||||
|
default=True,
|
||||||
|
help="Resample the audio to 24kHz",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--extractor",
|
||||||
|
type=str,
|
||||||
|
choices=["bigvgan", "hifigan"],
|
||||||
|
default="bigvgan",
|
||||||
|
help="The type of extractor to use",
|
||||||
|
)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def compute_fbank(args):
|
||||||
|
src_dir = Path(args.src_dir)
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
num_jobs = min(args.num_jobs, os.cpu_count())
|
||||||
|
dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip().split(" ")
|
||||||
|
|
||||||
|
logging.info(f"num_jobs: {num_jobs}")
|
||||||
|
logging.info(f"src_dir: {src_dir}")
|
||||||
|
logging.info(f"output_dir: {output_dir}")
|
||||||
|
logging.info(f"dataset_parts: {dataset_parts}")
|
||||||
|
if args.extractor == "bigvgan":
|
||||||
|
config = MatchaFbankConfig(
|
||||||
|
n_fft=1024,
|
||||||
|
n_mels=100,
|
||||||
|
sampling_rate=24_000,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
f_min=0,
|
||||||
|
f_max=None,
|
||||||
|
)
|
||||||
|
elif args.extractor == "hifigan":
|
||||||
|
config = MatchaFbankConfig(
|
||||||
|
n_fft=1024,
|
||||||
|
n_mels=80,
|
||||||
|
sampling_rate=22050,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
f_min=0,
|
||||||
|
f_max=8000,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Extractor {args.extractor} is not implemented")
|
||||||
|
|
||||||
|
extractor = MatchaFbank(config)
|
||||||
|
|
||||||
|
manifests = read_manifests_if_cached(
|
||||||
|
dataset_parts=dataset_parts,
|
||||||
|
output_dir=args.src_dir,
|
||||||
|
prefix=args.prefix,
|
||||||
|
suffix=args.suffix,
|
||||||
|
types=["recordings", "supervisions", "cuts"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with get_executor() as ex:
|
||||||
|
for partition, m in manifests.items():
|
||||||
|
logging.info(
|
||||||
|
f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
cut_set = CutSet.from_manifests(
|
||||||
|
recordings=m["recordings"],
|
||||||
|
supervisions=m["supervisions"],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
cut_set = m["cuts"]
|
||||||
|
|
||||||
|
if args.split > 1:
|
||||||
|
cut_sets = cut_set.split(args.split)
|
||||||
|
else:
|
||||||
|
cut_sets = [cut_set]
|
||||||
|
|
||||||
|
for idx, part in enumerate(cut_sets):
|
||||||
|
if args.split > 1:
|
||||||
|
storage_path = f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}_{idx}"
|
||||||
|
else:
|
||||||
|
storage_path = (
|
||||||
|
f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.resample_to_24kHz:
|
||||||
|
part = part.resample(24000)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
part = part.compute_and_store_features(
|
||||||
|
extractor=extractor,
|
||||||
|
storage_path=storage_path,
|
||||||
|
num_jobs=num_jobs if ex is None else 64,
|
||||||
|
executor=ex,
|
||||||
|
storage_type=LilcomChunkyWriter,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.split > 1:
|
||||||
|
cuts_filename = (
|
||||||
|
f"{args.prefix}_cuts_{partition}.{idx}.{args.suffix}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cuts_filename = f"{args.prefix}_cuts_{partition}.{args.suffix}"
|
||||||
|
|
||||||
|
part.to_file(f"{args.output_dir}/{cuts_filename}")
|
||||||
|
logging.info(f"Saved {cuts_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
|
# it wastes a lot of CPU and slow things down.
|
||||||
|
# Do this outside of main() in case it needs to take effect
|
||||||
|
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
compute_fbank(args)
|
26
egs/wenetspeech4tts/TTS/local/compute_wer.sh
Normal file
26
egs/wenetspeech4tts/TTS/local/compute_wer.sh
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
wav_dir=$1
|
||||||
|
wav_files=$(ls $wav_dir/*.wav)
|
||||||
|
# if wav_files is empty, then exit
|
||||||
|
if [ -z "$wav_files" ]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
label_file=$2
|
||||||
|
model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
|
||||||
|
|
||||||
|
if [ ! -d $model_path ]; then
|
||||||
|
pip install sherpa-onnx
|
||||||
|
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
|
||||||
|
fi
|
||||||
|
|
||||||
|
python3 local/offline-decode-files.py \
|
||||||
|
--tokens=$model_path/tokens.txt \
|
||||||
|
--paraformer=$model_path/model.int8.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
--sample-rate=24000 \
|
||||||
|
--log-dir $wav_dir \
|
||||||
|
--feature-dim=80 \
|
||||||
|
--label $label_file \
|
||||||
|
$wav_files
|
1
egs/wenetspeech4tts/TTS/local/fbank.py
Symbolic link
1
egs/wenetspeech4tts/TTS/local/fbank.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../ljspeech/TTS/matcha/fbank.py
|
495
egs/wenetspeech4tts/TTS/local/offline-decode-files.py
Executable file
495
egs/wenetspeech4tts/TTS/local/offline-decode-files.py
Executable file
@ -0,0 +1,495 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright (c) 2023 by manyeyes
|
||||||
|
# Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
||||||
|
file(s) with a non-streaming model.
|
||||||
|
|
||||||
|
(1) For paraformer
|
||||||
|
|
||||||
|
./python-api-examples/offline-decode-files.py \
|
||||||
|
--tokens=/path/to/tokens.txt \
|
||||||
|
--paraformer=/path/to/paraformer.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
--sample-rate=16000 \
|
||||||
|
--feature-dim=80 \
|
||||||
|
/path/to/0.wav \
|
||||||
|
/path/to/1.wav
|
||||||
|
|
||||||
|
(2) For transducer models from icefall
|
||||||
|
|
||||||
|
./python-api-examples/offline-decode-files.py \
|
||||||
|
--tokens=/path/to/tokens.txt \
|
||||||
|
--encoder=/path/to/encoder.onnx \
|
||||||
|
--decoder=/path/to/decoder.onnx \
|
||||||
|
--joiner=/path/to/joiner.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
--sample-rate=16000 \
|
||||||
|
--feature-dim=80 \
|
||||||
|
/path/to/0.wav \
|
||||||
|
/path/to/1.wav
|
||||||
|
|
||||||
|
(3) For CTC models from NeMo
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
|
||||||
|
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
|
||||||
|
--num-threads=2 \
|
||||||
|
--decoding-method=greedy_search \
|
||||||
|
--debug=false \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(4) For Whisper models
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||||
|
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||||
|
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||||
|
--whisper-task=transcribe \
|
||||||
|
--num-threads=1 \
|
||||||
|
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(5) For CTC models from WeNet
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
|
||||||
|
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
|
||||||
|
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
|
||||||
|
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
|
||||||
|
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
|
||||||
|
|
||||||
|
(6) For tdnn models of the yesno recipe from icefall
|
||||||
|
|
||||||
|
python3 ./python-api-examples/offline-decode-files.py \
|
||||||
|
--sample-rate=8000 \
|
||||||
|
--feature-dim=23 \
|
||||||
|
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
||||||
|
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
|
||||||
|
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
|
||||||
|
|
||||||
|
Please refer to
|
||||||
|
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||||
|
to install sherpa-onnx and to download non-streaming pre-trained models
|
||||||
|
used in this file.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import wave
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
help="Path to tokens.txt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hotwords-file",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The file containing hotwords, one words/phrases per line, like
|
||||||
|
HELLO WORLD
|
||||||
|
你好世界
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--hotwords-score",
|
||||||
|
type=float,
|
||||||
|
default=1.5,
|
||||||
|
help="""
|
||||||
|
The hotword score of each token for biasing word/phrase. Used only if
|
||||||
|
--hotwords-file is given.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--modeling-unit",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
|
||||||
|
Used only when hotwords-file is given.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-vocab",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""
|
||||||
|
The path to the bpe vocabulary, the bpe vocabulary is generated by
|
||||||
|
sentencepiece, you can also export the bpe vocabulary through a bpe model
|
||||||
|
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
|
||||||
|
and modeling-unit is bpe or cjkchar+bpe.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the joiner model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--paraformer",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx from Paraformer",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nemo-ctc",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx from NeMo CTC",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--wenet-ctc",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx from WeNet CTC",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tdnn-model",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to the model.onnx for the tdnn model of the yesno recipe",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-threads",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of threads for neural network computation",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-encoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to whisper encoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-decoder",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="Path to whisper decoder model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-language",
|
||||||
|
default="",
|
||||||
|
type=str,
|
||||||
|
help="""It specifies the spoken language in the input audio file.
|
||||||
|
Example values: en, fr, de, zh, jp.
|
||||||
|
Available languages for multilingual models can be found at
|
||||||
|
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||||
|
If not specified, we infer the language from the input audio file.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-task",
|
||||||
|
default="transcribe",
|
||||||
|
choices=["transcribe", "translate"],
|
||||||
|
type=str,
|
||||||
|
help="""For multilingual models, if you specify translate, the output
|
||||||
|
will be in English.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--whisper-tail-paddings",
|
||||||
|
default=-1,
|
||||||
|
type=int,
|
||||||
|
help="""Number of tail padding frames.
|
||||||
|
We have removed the 30-second constraint from whisper, so you need to
|
||||||
|
choose the amount of tail padding frames by yourself.
|
||||||
|
Use -1 to use a default value for tail padding.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""
|
||||||
|
The penalty applied on blank symbol during decoding.
|
||||||
|
Note: It is a positive value that would be applied to logits like
|
||||||
|
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||||
|
[batch_size, vocab] and blank id is 0).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="Valid values are greedy_search and modified_beam_search",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug",
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
help="True to show debug messages",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="""Sample rate of the feature extractor. Must match the one
|
||||||
|
expected by the model. Note: The input sound files can have a
|
||||||
|
different sample rate from this argument.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--feature-dim",
|
||||||
|
type=int,
|
||||||
|
default=80,
|
||||||
|
help="Feature dimension. Must match the one expected by the model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to decode. Each file must be of WAVE"
|
||||||
|
"format with a single channel, and each sample has 16-bit, "
|
||||||
|
"i.e., int16_t. "
|
||||||
|
"The sample rate of the file can be arbitrary and does not need to "
|
||||||
|
"be 16 kHz",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--name",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The directory containing the input sound files to decode",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-dir",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The directory containing the input sound files to decode",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--label",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="wav_base_name label",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def assert_file_exists(filename: str):
|
||||||
|
assert Path(filename).is_file(), (
|
||||||
|
f"{filename} does not exist!\n"
|
||||||
|
"Please refer to "
|
||||||
|
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
wave_filename:
|
||||||
|
Path to a wave file. It should be single channel and can be of type
|
||||||
|
32-bit floating point PCM. Its sample rate does not need to be 24kHz.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- A 1-D array of dtype np.float32 containing the samples,
|
||||||
|
which are normalized to the range [-1, 1].
|
||||||
|
- Sample rate of the wave file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
samples, sample_rate = sf.read(wave_filename, dtype="float32")
|
||||||
|
assert (
|
||||||
|
samples.ndim == 1
|
||||||
|
), f"Expected single channel, but got {samples.ndim} channels."
|
||||||
|
|
||||||
|
samples_float32 = samples.astype(np.float32)
|
||||||
|
|
||||||
|
return samples_float32, sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_text_alimeeting(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Text normalization similar to M2MeT challenge baseline.
|
||||||
|
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
text = text.replace(" ", "")
|
||||||
|
text = text.replace("<sil>", "")
|
||||||
|
text = text.replace("<%>", "")
|
||||||
|
text = text.replace("<->", "")
|
||||||
|
text = text.replace("<$>", "")
|
||||||
|
text = text.replace("<#>", "")
|
||||||
|
text = text.replace("<_>", "")
|
||||||
|
text = text.replace("<space>", "")
|
||||||
|
text = text.replace("`", "")
|
||||||
|
text = text.replace("&", "")
|
||||||
|
text = text.replace(",", "")
|
||||||
|
if re.search("[a-zA-Z]", text):
|
||||||
|
text = text.upper()
|
||||||
|
text = text.replace("A", "A")
|
||||||
|
text = text.replace("a", "A")
|
||||||
|
text = text.replace("b", "B")
|
||||||
|
text = text.replace("c", "C")
|
||||||
|
text = text.replace("k", "K")
|
||||||
|
text = text.replace("t", "T")
|
||||||
|
text = text.replace(",", "")
|
||||||
|
text = text.replace("丶", "")
|
||||||
|
text = text.replace("。", "")
|
||||||
|
text = text.replace("、", "")
|
||||||
|
text = text.replace("?", "")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
assert_file_exists(args.tokens)
|
||||||
|
assert args.num_threads > 0, args.num_threads
|
||||||
|
|
||||||
|
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||||
|
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||||
|
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||||
|
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||||
|
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||||
|
|
||||||
|
assert_file_exists(args.paraformer)
|
||||||
|
|
||||||
|
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||||
|
paraformer=args.paraformer,
|
||||||
|
tokens=args.tokens,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
sample_rate=args.sample_rate,
|
||||||
|
feature_dim=args.feature_dim,
|
||||||
|
decoding_method=args.decoding_method,
|
||||||
|
debug=args.debug,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Started!")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
streams, results = [], []
|
||||||
|
total_duration = 0
|
||||||
|
|
||||||
|
for i, wave_filename in enumerate(args.sound_files):
|
||||||
|
assert_file_exists(wave_filename)
|
||||||
|
samples, sample_rate = read_wave(wave_filename)
|
||||||
|
duration = len(samples) / sample_rate
|
||||||
|
total_duration += duration
|
||||||
|
s = recognizer.create_stream()
|
||||||
|
s.accept_waveform(sample_rate, samples)
|
||||||
|
|
||||||
|
streams.append(s)
|
||||||
|
if i % 10 == 0:
|
||||||
|
recognizer.decode_streams(streams)
|
||||||
|
results += [s.result.text for s in streams]
|
||||||
|
streams = []
|
||||||
|
print(f"Processed {i} files")
|
||||||
|
# process the last batch
|
||||||
|
if streams:
|
||||||
|
recognizer.decode_streams(streams)
|
||||||
|
results += [s.result.text for s in streams]
|
||||||
|
end_time = time.time()
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
|
results_dict = {}
|
||||||
|
for wave_filename, result in zip(args.sound_files, results):
|
||||||
|
print(f"{wave_filename}\n{result}")
|
||||||
|
print("-" * 10)
|
||||||
|
wave_basename = Path(wave_filename).stem
|
||||||
|
results_dict[wave_basename] = result
|
||||||
|
|
||||||
|
elapsed_seconds = end_time - start_time
|
||||||
|
rtf = elapsed_seconds / total_duration
|
||||||
|
print(f"num_threads: {args.num_threads}")
|
||||||
|
print(f"decoding_method: {args.decoding_method}")
|
||||||
|
print(f"Wave duration: {total_duration:.3f} s")
|
||||||
|
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||||
|
print(
|
||||||
|
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||||
|
)
|
||||||
|
if args.label:
|
||||||
|
from icefall.utils import store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
labels_dict = {}
|
||||||
|
with open(args.label, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
# fields = line.strip().split(" ")
|
||||||
|
# fields = [item for item in fields if item]
|
||||||
|
# assert len(fields) == 4
|
||||||
|
# prompt_text, prompt_audio, text, audio_path = fields
|
||||||
|
|
||||||
|
fields = line.strip().split("|")
|
||||||
|
fields = [item for item in fields if item]
|
||||||
|
assert len(fields) == 4
|
||||||
|
audio_path, prompt_text, prompt_audio, text = fields
|
||||||
|
labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
|
||||||
|
|
||||||
|
final_results = []
|
||||||
|
for key, value in results_dict.items():
|
||||||
|
final_results.append((key, labels_dict[key], value))
|
||||||
|
|
||||||
|
store_transcripts(
|
||||||
|
filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
|
||||||
|
)
|
||||||
|
with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
|
||||||
|
write_error_stats(f, "test-set", final_results, enable_log=True)
|
||||||
|
|
||||||
|
with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
|
||||||
|
print(f.readline()) # WER
|
||||||
|
print(f.readline()) # Detailed errors
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -98,3 +98,44 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
|||||||
fi
|
fi
|
||||||
python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
|
python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
subset="Basic"
|
||||||
|
prefix="wenetspeech4tts"
|
||||||
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
|
log "Stage 5: Generate fbank (used by ./f5-tts)"
|
||||||
|
mkdir -p data/fbank
|
||||||
|
if [ ! -e data/fbank/.${prefix}.done ]; then
|
||||||
|
./local/compute_mel_feat.py --dataset-parts $subset --split 100
|
||||||
|
touch data/fbank/.${prefix}.done
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||||
|
log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)"
|
||||||
|
if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
|
||||||
|
echo "Combining ${prefix} cuts"
|
||||||
|
pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")
|
||||||
|
lhotse combine $pieces data/fbank/${prefix}_cuts_${subset}.jsonl.gz
|
||||||
|
fi
|
||||||
|
if [ ! -e data/fbank/.${prefix}_split.done ]; then
|
||||||
|
echo "Splitting ${prefix} cuts into train, valid and test sets"
|
||||||
|
|
||||||
|
lhotse subset --last 800 \
|
||||||
|
data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
|
||||||
|
data/fbank/${prefix}_cuts_validtest.jsonl.gz
|
||||||
|
lhotse subset --first 400 \
|
||||||
|
data/fbank/${prefix}_cuts_validtest.jsonl.gz \
|
||||||
|
data/fbank/${prefix}_cuts_valid.jsonl.gz
|
||||||
|
lhotse subset --last 400 \
|
||||||
|
data/fbank/${prefix}_cuts_validtest.jsonl.gz \
|
||||||
|
data/fbank/${prefix}_cuts_test.jsonl.gz
|
||||||
|
|
||||||
|
rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
|
||||||
|
|
||||||
|
n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
|
||||||
|
lhotse subset --first $n \
|
||||||
|
data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
|
||||||
|
data/fbank/${prefix}_cuts_train.jsonl.gz
|
||||||
|
touch data/fbank/.${prefix}_split.done
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
Loading…
x
Reference in New Issue
Block a user