Merge branch 'k2-fsa:master' into einichi

This commit is contained in:
Machiko Bailey 2025-01-27 18:13:57 -05:00 committed by GitHub
commit efc0536b6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 7166 additions and 145 deletions

View File

@ -1 +1 @@
../../../icefall/shared/
../../../icefall/shared

View File

@ -1165,23 +1165,34 @@ def train_one_epoch(
rank=rank,
)
if batch_idx % 100 == 0 and params.use_autocast:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
if params.use_autocast:
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
save_bad_model(suffix="-first-warning")
saved_bad_model = True
if not params.inf_check:
register_inf_check_hooks(model)
logging.warning(f"Grad scale is small: {cur_grad_scale}")
if cur_grad_scale < 1.0e-05:
save_bad_model()
raise_grad_scale_is_too_small_error(cur_grad_scale)
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
if (
batch_idx % 25 == 0
and cur_grad_scale < 2.0
or batch_idx % 100 == 0
and cur_grad_scale < 8.0
or batch_idx % 400 == 0
and cur_grad_scale < 32.0
):
scaler.update(cur_grad_scale * 2.0)
if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0
@ -1335,7 +1346,7 @@ def run(rank, world_size, args):
clipping_scale=2.0,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=0.1)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")

View File

@ -17,6 +17,7 @@ class MatchaFbankConfig:
win_length: int
f_min: float
f_max: float
device: str = "cuda"
@register_extractor
@ -46,7 +47,7 @@ class MatchaFbank(FeatureExtractor):
f"Mismatched sampling rate: extractor expects {expected_sr}, "
f"got {sampling_rate}"
)
samples = torch.from_numpy(samples)
samples = torch.from_numpy(samples).to(self.device)
assert samples.ndim == 2, samples.shape
assert samples.shape[0] == 1, samples.shape
@ -81,7 +82,7 @@ class MatchaFbank(FeatureExtractor):
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
).squeeze(0)
return mel.numpy()
return mel.cpu().numpy()
@property
def frame_shift(self) -> Seconds:

View File

@ -68,5 +68,69 @@ python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_
--text-extractor pypinyin_initials_finals --top-p ${top_p}
```
# [F5-TTS](https://arxiv.org/abs/2410.06885)
./f5-tts contains the code for training F5-TTS model.
Generated samples and training logs of wenetspeech basic 7k hours data can be found [here](https://huggingface.co/yuekai/f5-tts-small-wenetspeech4tts-basic/tensorboard).
Preparation:
```
bash prepare.sh --stage 5 --stop_stage 6
```
(Note: To compatiable with F5-TTS official checkpoint, we direclty use `vocab.txt` from [here.](https://github.com/SWivid/F5-TTS/blob/129014c5b43f135b0100d49a0c6804dd4cf673e1/data/Emilia_ZH_EN_pinyin/vocab.txt) To generate your own `vocab.txt`, you may refer to [the script](https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/train/datasets/prepare_emilia.py).)
The training command is given below:
```
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
world_size=8
exp_dir=exp/f5-tts-small
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
--base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
--exp-dir ${exp_dir} --world-size ${world_size}
```
To inference with Icefall Wenetspeech4TTS trained F5-Small, use:
```
huggingface-cli login
huggingface-cli download --local-dir seed_tts_eval yuekai/seed_tts_eval --repo-type dataset
huggingface-cli download --local-dir ${exp_dir} yuekai/f5-tts-small-wenetspeech4tts-basic
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-56-avg-14.pt
# skip
python3 f5-tts/generate_averaged_model.py \
--epoch 56 \
--avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--exp-dir exp/f5_small
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
bash local/compute_wer.sh $output_dir $manifest
```
To inference with official Emilia trained F5-Base, use:
```
huggingface-cli login
huggingface-cli download --local-dir seed_tts_eval yuekai/seed_tts_eval --repo-type dataset
huggingface-cli download --local-dir F5-TTS SWivid/F5-TTS
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
manifest=./seed_tts_eval/seedtts_testset/zh/meta.lst
model_path=./F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir
bash local/compute_wer.sh $output_dir $manifest
```
# Credits
- [vall-e](https://github.com/lifeiteng/vall-e)
- [VALL-E](https://github.com/lifeiteng/vall-e)
- [F5-TTS](https://github.com/SWivid/F5-TTS)

View File

@ -0,0 +1,173 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
# Copyright 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) use the checkpoint exp_dir/epoch-xxx.pt
python3 bin/generate_averaged_model.py \
--epoch 40 \
--avg 5 \
--exp-dir ${exp_dir}
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
"""
import argparse
from pathlib import Path
import k2
import torch
from train import add_model_arguments, get_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
)
from icefall.utils import AttributeDict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = AttributeDict()
params.update(vars(args))
if params.iter > 0:
params.suffix = f"checkpoint-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
print("Script started")
device = torch.device("cpu")
print(f"Device: {device}")
print("About to create model")
filename = f"{params.exp_dir}/epoch-{params.epoch}.pt"
checkpoint = torch.load(filename, map_location=device)
args = AttributeDict(checkpoint)
model = get_model(args)
if params.iter > 0:
# TODO FIX ME
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
print(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
filename = params.exp_dir / f"checkpoint-{params.iter}-avg-{params.avg}.pt"
torch.save({"model": model.state_dict()}, filename)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
print(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
filenames = [
f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1)
]
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
checkpoint["model"] = model.state_dict()
torch.save(checkpoint, filename)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
print("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,364 @@
#!/usr/bin/env python3
# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py
"""
Usage:
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx
# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst
python3 f5-tts/generate_averaged_model.py \
--epoch 56 \
--avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--exp-dir exp/f5_small
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
bash local/compute_wer.sh $output_dir $manifest
"""
import argparse
import logging
import math
import os
import random
import time
from pathlib import Path
import torch
import torch.nn.functional as F
import torchaudio
from accelerate import Accelerator
from bigvganinference import BigVGANInference
from model.cfm import CFM
from model.dit import DiT
from model.modules import MelSpec
from model.utils import convert_char_to_pinyin
from tqdm import tqdm
from train import (
add_model_arguments,
get_model,
get_tokenizer,
load_F5_TTS_pretrained_checkpoint,
)
from icefall.checkpoint import load_checkpoint
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tokens",
type=str,
default="f5-tts/vocab.txt",
help="Path to the unique text tokens file",
)
parser.add_argument(
"--model-path",
type=str,
default="/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt",
help="Path to the unique text tokens file",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="The seed for random generators intended for reproducibility",
)
parser.add_argument(
"--nfe",
type=int,
default=16,
help="The number of steps for the neural ODE",
)
parser.add_argument(
"--manifest-file",
type=str,
default="/path/seed_tts_eval/seedtts_testset/zh/meta.lst",
help="The manifest file in seed_tts_eval format",
)
parser.add_argument(
"--output-dir",
type=Path,
default="results",
help="The output directory to save the generated wavs",
)
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
add_model_arguments(parser)
return parser.parse_args()
def get_inference_prompt(
metainfo,
speed=1.0,
tokenizer="pinyin",
polyphone=True,
target_sample_rate=24000,
n_fft=1024,
win_length=1024,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
num_buckets=200,
min_secs=3,
max_secs=40,
):
prompts_all = []
min_tokens = min_secs * target_sample_rate // hop_length
max_tokens = max_secs * target_sample_rate // hop_length
batch_accum = [0] * num_buckets
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
[[] for _ in range(num_buckets)] for _ in range(6)
)
mel_spectrogram = MelSpec(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
)
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
metainfo, desc="Processing prompts..."
):
# Audio
ref_audio, ref_sr = torchaudio.load(prompt_wav)
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
if ref_rms < target_rms:
ref_audio = ref_audio * target_rms / ref_rms
assert (
ref_audio.shape[-1] > 5000
), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
if ref_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
ref_audio = resampler(ref_audio)
# Text
if len(prompt_text[-1].encode("utf-8")) == 1:
prompt_text = prompt_text + " "
text = [prompt_text + gt_text]
if tokenizer == "pinyin":
text_list = convert_char_to_pinyin(text, polyphone=polyphone)
else:
text_list = text
# Duration, mel frame length
ref_mel_len = ref_audio.shape[-1] // hop_length
if use_truth_duration:
gt_audio, gt_sr = torchaudio.load(gt_wav)
if gt_sr != target_sample_rate:
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
gt_audio = resampler(gt_audio)
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
# # test vocoder resynthesis
# ref_audio = gt_audio
else:
ref_text_len = len(prompt_text.encode("utf-8"))
gen_text_len = len(gt_text.encode("utf-8"))
total_mel_len = ref_mel_len + int(
ref_mel_len / ref_text_len * gen_text_len / speed
)
# to mel spectrogram
ref_mel = mel_spectrogram(ref_audio)
ref_mel = ref_mel.squeeze(0)
# deal with batch
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
assert (
min_tokens <= total_mel_len <= max_tokens
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
bucket_i = math.floor(
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
)
utts[bucket_i].append(utt)
ref_rms_list[bucket_i].append(ref_rms)
ref_mels[bucket_i].append(ref_mel)
ref_mel_lens[bucket_i].append(ref_mel_len)
total_mel_lens[bucket_i].append(total_mel_len)
final_text_list[bucket_i].extend(text_list)
batch_accum[bucket_i] += total_mel_len
if batch_accum[bucket_i] >= infer_batch_size:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
batch_accum[bucket_i] = 0
(
utts[bucket_i],
ref_rms_list[bucket_i],
ref_mels[bucket_i],
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
) = (
[],
[],
[],
[],
[],
[],
)
# add residual
for bucket_i, bucket_frames in enumerate(batch_accum):
if bucket_frames > 0:
prompts_all.append(
(
utts[bucket_i],
ref_rms_list[bucket_i],
padded_mel_batch(ref_mels[bucket_i]),
ref_mel_lens[bucket_i],
total_mel_lens[bucket_i],
final_text_list[bucket_i],
)
)
# not only leave easy work for last workers
random.seed(666)
random.shuffle(prompts_all)
return prompts_all
def padded_mel_batch(ref_mels):
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
padded_ref_mels = []
for mel in ref_mels:
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value=0)
padded_ref_mels.append(padded_ref_mel)
padded_ref_mels = torch.stack(padded_ref_mels)
padded_ref_mels = padded_ref_mels.permute(0, 2, 1)
return padded_ref_mels
def get_seedtts_testset_metainfo(metalst):
f = open(metalst)
lines = f.readlines()
f.close()
metainfo = []
for line in lines:
assert len(line.strip().split("|")) == 4
utt, prompt_text, prompt_wav, gt_text = line.strip().split("|")
utt = Path(utt).stem
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
if not os.path.isabs(prompt_wav):
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
return metainfo
def main():
args = get_parser()
accelerator = Accelerator()
device = f"cuda:{accelerator.process_index}"
metainfo = get_seedtts_testset_metainfo(args.manifest_file)
prompts_all = get_inference_prompt(
metainfo,
speed=1.0,
tokenizer="pinyin",
target_sample_rate=24_000,
n_mel_channels=100,
hop_length=256,
mel_spec_type="bigvgan",
target_rms=0.1,
use_truth_duration=False,
infer_batch_size=1,
)
vocoder = BigVGANInference.from_pretrained(
"./bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
)
vocoder = vocoder.eval().to(device)
model = get_model(args).eval().to(device)
checkpoint = torch.load(args.model_path, map_location="cpu")
if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
model = load_F5_TTS_pretrained_checkpoint(model, args.model_path)
else:
_ = load_checkpoint(
args.model_path,
model=model,
)
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
start = time.time()
with accelerator.split_between_processes(prompts_all) as prompts:
for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
(
utts,
ref_rms_list,
ref_mels,
ref_mel_lens,
total_mel_lens,
final_text_list,
) = prompt
ref_mels = ref_mels.to(device)
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
# Inference
with torch.inference_mode():
generated, _ = model.sample(
cond=ref_mels,
text=final_text_list,
duration=total_mel_lens,
lens=ref_mel_lens,
steps=args.nfe,
cfg_strength=2.0,
sway_sampling_coef=args.swaysampling,
no_ref_audio=False,
seed=args.seed,
)
for i, gen in enumerate(generated):
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
target_rms = 0.1
target_sample_rate = 24_000
if ref_rms_list[i] < target_rms:
generated_wave = generated_wave * ref_rms_list[i] / target_rms
torchaudio.save(
f"{args.output_dir}/{utts[i]}.wav",
generated_wave,
target_sample_rate,
)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
timediff = time.time() - start
print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,3 @@
# Introduction
Files in this folder are copied from
https://github.com/SWivid/F5-TTS/tree/main/src/f5_tts/model

View File

@ -0,0 +1,326 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from random import random
from typing import Callable
import torch
import torch.nn.functional as F
from model.modules import MelSpec
from model.utils import (
default,
exists,
lens_to_mask,
list_str_to_idx,
list_str_to_tensor,
mask_from_frac_lengths,
)
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint
class CFM(nn.Module):
def __init__(
self,
transformer: nn.Module,
sigma=0.0,
odeint_kwargs: dict = dict(
# atol = 1e-5,
# rtol = 1e-5,
method="euler" # 'midpoint'
),
audio_drop_prob=0.3,
cond_drop_prob=0.2,
num_channels=None,
mel_spec_module: nn.Module | None = None,
mel_spec_kwargs: dict = dict(),
frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
vocab_char_map: dict[str:int] | None = None,
):
super().__init__()
self.frac_lengths_mask = frac_lengths_mask
# mel spec
self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
num_channels = default(num_channels, self.mel_spec.n_mel_channels)
self.num_channels = num_channels
# classifier-free guidance
self.audio_drop_prob = audio_drop_prob
self.cond_drop_prob = cond_drop_prob
# transformer
self.transformer = transformer
dim = transformer.dim
self.dim = dim
# conditional flow related
self.sigma = sigma
# sampling related
self.odeint_kwargs = odeint_kwargs
# vocab map for tokenization
self.vocab_char_map = vocab_char_map
@property
def device(self):
return next(self.parameters()).device
@torch.no_grad()
def sample(
self,
cond: float["b n d"] | float["b nw"], # noqa: F722
text: int["b nt"] | list[str], # noqa: F722
duration: int | int["b"], # noqa: F821
*,
lens: int["b"] | None = None, # noqa: F821
steps=32,
cfg_strength=1.0,
sway_sampling_coef=None,
seed: int | None = None,
max_duration=4096,
vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
no_ref_audio=False,
duplicate_test=False,
t_inter=0.1,
edit_mask=None,
):
self.eval()
# raw wave
if cond.ndim == 2:
cond = self.mel_spec(cond)
cond = cond.permute(0, 2, 1)
assert cond.shape[-1] == self.num_channels
cond = cond.to(next(self.parameters()).dtype)
batch, cond_seq_len, device = *cond.shape[:2], cond.device
if not exists(lens):
lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
# text
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
if exists(text):
text_lens = (text != -1).sum(dim=-1)
lens = torch.maximum(
text_lens, lens
) # make sure lengths are at least those of the text characters
# duration
cond_mask = lens_to_mask(lens)
if edit_mask is not None:
cond_mask = cond_mask & edit_mask
if isinstance(duration, int):
duration = torch.full((batch,), duration, device=device, dtype=torch.long)
duration = torch.maximum(
lens + 1, duration
) # just add one token so something is generated
duration = duration.clamp(max=max_duration)
max_duration = duration.amax()
# duplicate test corner for inner time step oberservation
if duplicate_test:
test_cond = F.pad(
cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0
)
cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
cond_mask = F.pad(
cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
)
cond_mask = cond_mask.unsqueeze(-1)
step_cond = torch.where(
cond_mask, cond, torch.zeros_like(cond)
) # allow direct control (cut cond audio) with lens passed in
if batch > 1:
mask = lens_to_mask(duration)
else: # save memory and speed up, as single inference need no mask currently
mask = None
# test for no ref audio
if no_ref_audio:
cond = torch.zeros_like(cond)
# neural ode
def fn(t, x):
# at each step, conditioning is fixed
# step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
# predict flow
pred = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
drop_audio_cond=False,
drop_text=False,
)
if cfg_strength < 1e-5:
return pred
null_pred = self.transformer(
x=x,
cond=step_cond,
text=text,
time=t,
mask=mask,
drop_audio_cond=True,
drop_text=True,
)
return pred + (pred - null_pred) * cfg_strength
# noise input
# to make sure batch inference result is same with different batch size, and for sure single inference
# still some difference maybe due to convolutional layers
y0 = []
for dur in duration:
if exists(seed):
torch.manual_seed(seed)
y0.append(
torch.randn(
dur, self.num_channels, device=self.device, dtype=step_cond.dtype
)
)
y0 = pad_sequence(y0, padding_value=0, batch_first=True)
t_start = 0
# duplicate test corner for inner time step oberservation
if duplicate_test:
t_start = t_inter
y0 = (1 - t_start) * y0 + t_start * test_cond
steps = int(steps * (1 - t_start))
t = torch.linspace(
t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype
)
if sway_sampling_coef is not None:
t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
sampled = trajectory[-1]
out = sampled
out = torch.where(cond_mask, cond, out)
if exists(vocoder):
out = out.permute(0, 2, 1)
out = vocoder(out)
return out, trajectory
def forward(
self,
inp: float["b n d"] | float["b nw"], # mel or raw wave # noqa: F722
text: int["b nt"] | list[str], # noqa: F722
*,
lens: int["b"] | None = None, # noqa: F821
noise_scheduler: str | None = None,
):
# handle raw wave
if inp.ndim == 2:
inp = self.mel_spec(inp)
inp = inp.permute(0, 2, 1)
assert inp.shape[-1] == self.num_channels
batch, seq_len, dtype, device, _σ1 = (
*inp.shape[:2],
inp.dtype,
self.device,
self.sigma,
)
# handle text as string
if isinstance(text, list):
if exists(self.vocab_char_map):
text = list_str_to_idx(text, self.vocab_char_map).to(device)
else:
text = list_str_to_tensor(text).to(device)
assert text.shape[0] == batch
# lens and mask
if not exists(lens):
lens = torch.full((batch,), seq_len, device=device)
mask = lens_to_mask(
lens, length=seq_len
) # useless here, as collate_fn will pad to max length in batch
# get a random span to mask out for training conditionally
frac_lengths = (
torch.zeros((batch,), device=self.device)
.float()
.uniform_(*self.frac_lengths_mask)
)
rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
if exists(mask):
rand_span_mask &= mask
# mel is x1
x1 = inp
# x0 is gaussian noise
x0 = torch.randn_like(x1)
# time step
time = torch.rand((batch,), dtype=dtype, device=self.device)
# TODO. noise_scheduler
# sample xt (φ_t(x) in the paper)
t = time.unsqueeze(-1).unsqueeze(-1)
φ = (1 - t) * x0 + t * x1
flow = x1 - x0
# only predict what is within the random mask span for infilling
cond = torch.where(rand_span_mask[..., None], torch.zeros_like(x1), x1)
# transformer and cfg training with a drop rate
drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
if random() < self.cond_drop_prob: # p_uncond in voicebox paper
drop_audio_cond = True
drop_text = True
else:
drop_text = False
# if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
# adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
pred = self.transformer(
x=φ,
cond=cond,
text=text,
time=time,
drop_audio_cond=drop_audio_cond,
drop_text=drop_text,
)
# flow matching loss
loss = F.mse_loss(pred, flow, reduction="none")
loss = loss[rand_span_mask]
return loss.mean(), cond, pred

View File

@ -0,0 +1,210 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import torch
import torch.nn.functional as F
from model.modules import (
AdaLayerNormZero_Final,
ConvNeXtV2Block,
ConvPositionEmbedding,
DiTBlock,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding
# Text embedding
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(
text_num_embeds + 1, text_dim
) # use 0 as filler token
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
self.register_buffer(
"freqs_cis",
precompute_freqs_cis(text_dim, self.precompute_max_pos),
persistent=False,
)
self.text_blocks = nn.Sequential(
*[
ConvNeXtV2Block(text_dim, text_dim * conv_mult)
for _ in range(conv_layers)
]
)
else:
self.extra_modeling = False
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
text = (
text + 1
) # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[
:, :seq_len
] # curtail if character tokens are more than the mel spec tokens
batch, text_len = text.shape[0], text.shape[1]
text = F.pad(text, (0, seq_len - text_len), value=0)
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), dtype=torch.long)
pos_idx = get_pos_embed_indices(
batch_start, seq_len, max_pos=self.precompute_max_pos
)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
# convnextv2 blocks
text = self.text_blocks(text)
return text
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
def __init__(self, mel_dim, text_dim, out_dim):
super().__init__()
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
def forward(
self,
x: float["b n d"], # noqa: F722
cond: float["b n d"], # noqa: F722
text_embed: float["b n d"], # noqa: F722
drop_audio_cond=False,
):
if drop_audio_cond: # cfg for cond audio
cond = torch.zeros_like(cond)
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
x = self.conv_pos_embed(x) + x
return x
# Transformer backbone using DiT blocks
class DiT(nn.Module):
def __init__(
self,
*,
dim,
depth=8,
heads=8,
dim_head=64,
dropout=0.1,
ff_mult=4,
mel_dim=100,
text_num_embeds=256,
text_dim=None,
conv_layers=0,
long_skip_connection=False,
checkpoint_activations=False,
):
super().__init__()
self.time_embed = TimestepEmbedding(dim)
if text_dim is None:
text_dim = mel_dim
self.text_embed = TextEmbedding(
text_num_embeds, text_dim, conv_layers=conv_layers
)
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
self.rotary_embed = RotaryEmbedding(dim_head)
self.dim = dim
self.depth = depth
self.transformer_blocks = nn.ModuleList(
[
DiTBlock(
dim=dim,
heads=heads,
dim_head=dim_head,
ff_mult=ff_mult,
dropout=dropout,
)
for _ in range(depth)
]
)
self.long_skip_connection = (
nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
)
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
self.checkpoint_activations = checkpoint_activations
def ckpt_wrapper(self, module):
# https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
def ckpt_forward(*inputs):
outputs = module(*inputs)
return outputs
return ckpt_forward
def forward(
self,
x: float["b n d"], # nosied input audio # noqa: F722
cond: float["b n d"], # masked cond audio # noqa: F722
text: int["b nt"], # text # noqa: F722
time: float["b"] | float[""], # time step # noqa: F821 F722
drop_audio_cond, # cfg for cond audio
drop_text, # cfg for text
mask: bool["b n"] | None = None, # noqa: F722
):
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
time = time.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
rope = self.rotary_embed.forward_from_seq_len(seq_len)
if self.long_skip_connection is not None:
residual = x
for block in self.transformer_blocks:
if self.checkpoint_activations:
x = torch.utils.checkpoint.checkpoint(
self.ckpt_wrapper(block), x, t, mask, rope
)
else:
x = block(x, t, mask=mask, rope=rope)
if self.long_skip_connection is not None:
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
x = self.norm_out(x, t)
output = self.proj_out(x)
return output

View File

@ -0,0 +1,728 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import math
from typing import Optional
import torch
import torch.nn.functional as F
import torchaudio
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb
# raw wav to mel spec
mel_basis_cache = {}
hann_window_cache = {}
def get_bigvgan_mel_spectrogram(
waveform,
n_fft=1024,
n_mel_channels=100,
target_sample_rate=24000,
hop_length=256,
win_length=1024,
fmin=0,
fmax=None,
center=False,
): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
device = waveform.device
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
if key not in mel_basis_cache:
mel = librosa_mel_fn(
sr=target_sample_rate,
n_fft=n_fft,
n_mels=n_mel_channels,
fmin=fmin,
fmax=fmax,
)
mel_basis_cache[key] = (
torch.from_numpy(mel).float().to(device)
) # TODO: why they need .float()?
hann_window_cache[key] = torch.hann_window(win_length).to(device)
mel_basis = mel_basis_cache[key]
hann_window = hann_window_cache[key]
padding = (n_fft - hop_length) // 2
waveform = torch.nn.functional.pad(
waveform.unsqueeze(1), (padding, padding), mode="reflect"
).squeeze(1)
spec = torch.stft(
waveform,
n_fft,
hop_length=hop_length,
win_length=win_length,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
mel_spec = torch.matmul(mel_basis, spec)
mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
return mel_spec
def get_vocos_mel_spectrogram(
waveform,
n_fft=1024,
n_mel_channels=100,
target_sample_rate=24000,
hop_length=256,
win_length=1024,
):
mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=target_sample_rate,
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=1,
center=True,
normalized=False,
norm=None,
).to(waveform.device)
if len(waveform.shape) == 3:
waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
assert len(waveform.shape) == 2
mel = mel_stft(waveform)
mel = mel.clamp(min=1e-5).log()
return mel
class MelSpec(nn.Module):
def __init__(
self,
n_fft=1024,
hop_length=256,
win_length=1024,
n_mel_channels=100,
target_sample_rate=24_000,
mel_spec_type="vocos",
):
super().__init__()
assert mel_spec_type in ["vocos", "bigvgan"], print(
"We only support two extract mel backend: vocos or bigvgan"
)
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.n_mel_channels = n_mel_channels
self.target_sample_rate = target_sample_rate
if mel_spec_type == "vocos":
self.extractor = get_vocos_mel_spectrogram
elif mel_spec_type == "bigvgan":
self.extractor = get_bigvgan_mel_spectrogram
self.register_buffer("dummy", torch.tensor(0), persistent=False)
def forward(self, wav):
if self.dummy.device != wav.device:
self.to(wav.device)
mel = self.extractor(
waveform=wav,
n_fft=self.n_fft,
n_mel_channels=self.n_mel_channels,
target_sample_rate=self.target_sample_rate,
hop_length=self.hop_length,
win_length=self.win_length,
)
return mel
# sinusoidal position embedding
class SinusPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# convolutional position embedding
class ConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = self.conv1d(x)
out = x.permute(0, 2, 1)
if mask is not None:
out = out.masked_fill(~mask, 0.0)
return out
# rotary positional embedding related
def precompute_freqs_cis(
dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0
):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
# length = length if isinstance(length, int) else length.max()
scale = scale * torch.ones_like(
start, dtype=torch.float32
) # in case scale is a scalar
pos = (
start.unsqueeze(1)
+ (
torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0)
* scale.unsqueeze(1)
).long()
)
# avoid extra long error.
pos = torch.where(pos < max_pos, pos, max_pos - 1)
return pos
# Global Response Normalization layer (Instance Normalization ?)
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
dilation: int = 1,
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim, intermediate_dim
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.transpose(1, 2) # b n d -> b d n
x = self.dwconv(x)
x = x.transpose(1, 2) # b d n -> b n d
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
return residual + x
# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNormZero(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 6)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb=None):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(
emb, 6, dim=1
)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
# AdaLayerNormZero for final layer
# return only with modulated x for attn input, cuz no more mlp modulation
class AdaLayerNormZero_Final(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 2)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
# FeedForward
class FeedForward(nn.Module):
def __init__(
self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
activation = nn.GELU(approximate=approximate)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
self.ff = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.ff(x)
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py
class Attention(nn.Module):
def __init__(
self,
processor: JointAttnProcessor | AttnProcessor,
dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.processor = processor
self.dim = dim
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.context_dim = context_dim
self.context_pre_only = context_pre_only
self.to_q = nn.Linear(dim, self.inner_dim)
self.to_k = nn.Linear(dim, self.inner_dim)
self.to_v = nn.Linear(dim, self.inner_dim)
if self.context_dim is not None:
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, dim))
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = nn.Linear(self.inner_dim, dim)
def forward(
self,
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
else:
return self.processor(self, x, mask=mask, rope=rope)
# Attention processor
class AttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (
(xpos_scale, xpos_scale**-1.0)
if xpos_scale is not None
else (1.0, 1.0)
)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(
batch_size, attn.heads, query.shape[-2], key.shape[-2]
)
else:
attn_mask = None
x = F.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
return x
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py
class JointAttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
residual = x
batch_size = c.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# `context` projections.
c_query = attn.to_q_c(c)
c_key = attn.to_k_c(c)
c_value = attn.to_v_c(c)
# apply rope for context and noised input independently
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (
(xpos_scale, xpos_scale**-1.0)
if xpos_scale is not None
else (1.0, 1.0)
)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
if c_rope is not None:
freqs, xpos_scale = c_rope
q_xpos_scale, k_xpos_scale = (
(xpos_scale, xpos_scale**-1.0)
if xpos_scale is not None
else (1.0, 1.0)
)
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
# attention
query = torch.cat([query, c_query], dim=1)
key = torch.cat([key, c_key], dim=1)
value = torch.cat([value, c_value], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(
batch_size, attn.heads, query.shape[-2], key.shape[-2]
)
else:
attn_mask = None
x = F.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# Split the attention outputs.
x, c = (
x[:, : residual.shape[1]],
x[:, residual.shape[1] :],
)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if not attn.context_pre_only:
c = attn.to_out_c(c)
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
return x, c
# DiT Block
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
)
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
attn_output = self.attn(x=norm, mask=mask, rope=rope)
# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm)
x = x + gate_mlp.unsqueeze(1) * ff_output
return x
# MMDiT Block https://arxiv.org/abs/2403.03206
class MMDiTBlock(nn.Module):
r"""
modified from diffusers/src/diffusers/models/attention.py
notes.
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
_x: noised input related. (right part)
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
"""
def __init__(
self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False
):
super().__init__()
self.context_pre_only = context_pre_only
self.attn_norm_c = (
AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
)
self.attn_norm_x = AdaLayerNormZero(dim)
self.attn = Attention(
processor=JointAttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
context_dim=dim,
context_pre_only=context_pre_only,
)
if not context_pre_only:
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
)
else:
self.ff_norm_c = None
self.ff_c = None
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_x = FeedForward(
dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
)
def forward(
self, x, c, t, mask=None, rope=None, c_rope=None
): # x: noised input, c: context, t: time embedding
# pre-norm & modulation for attention input
if self.context_pre_only:
norm_c = self.attn_norm_c(c, t)
else:
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(
c, emb=t
)
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(
x, emb=t
)
# attention
x_attn_output, c_attn_output = self.attn(
x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope
)
# process attention output for context c
if self.context_pre_only:
c = None
else: # if not last layer
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
norm_c = (
self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
)
c_ff_output = self.ff_c(norm_c)
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
# process attention output for input x
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
x_ff_output = self.ff_x(norm_x)
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
return c, x
# time step conditioning embedding
class TimestepEmbedding(nn.Module):
def __init__(self, dim, freq_embed_dim=256):
super().__init__()
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(
nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
)
def forward(self, timestep: float["b"]): # noqa: F821
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d
return time

View File

@ -0,0 +1,206 @@
from __future__ import annotations
import os
import random
from collections import defaultdict
from importlib.resources import files
import jieba
import torch
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence
# seed everything
def seed_everything(seed=0):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# helpers
def exists(v):
return v is not None
def default(v, d):
return v if exists(v) else d
# tensor helpers
def lens_to_mask(
t: int["b"], length: int | None = None # noqa: F722 F821
) -> bool["b n"]: # noqa: F722 F821
if not exists(length):
length = t.amax()
seq = torch.arange(length, device=t.device)
return seq[None, :] < t[:, None]
def mask_from_start_end_indices(
seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821
):
max_seq_len = seq_len.max().item()
seq = torch.arange(max_seq_len, device=start.device).long()
start_mask = seq[None, :] >= start[:, None]
end_mask = seq[None, :] < end[:, None]
return start_mask & end_mask
def mask_from_frac_lengths(
seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821
):
lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths
rand = torch.rand_like(frac_lengths)
start = (max_start * rand).long().clamp(min=0)
end = start + lengths
return mask_from_start_end_indices(seq_len, start, end)
def maybe_masked_mean(
t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821
) -> float["b d"]: # noqa: F722 F821
if not exists(mask):
return t.mean(dim=1)
t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
num = t.sum(dim=1)
den = mask.float().sum(dim=1)
return num / den.clamp(min=1.0)
# simple utf-8 tokenizer, since paper went character based
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]: # noqa: F722
list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text] # ByT5 style
text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
return text
# char tokenizer, based on custom dataset's extracted .txt file
def list_str_to_idx(
text: list[str] | list[list[str]],
vocab_char_map: dict[str, int], # {char: idx}
padding_value=-1,
) -> int["b nt"]: # noqa: F722
list_idx_tensors = [
torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text
] # pinyin or char style
text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
return text
# Get tokenizer
def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
"""
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
- "char" for char-wise tokenizer, need .txt vocab_file
- "byte" for utf-8 tokenizer
- "custom" if you're directly passing in a path to the vocab.txt you want to use
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
- if use "char", derived from unfiltered character & symbol counts of custom dataset
- if use "byte", set to 256 (unicode byte range)
"""
if tokenizer in ["pinyin", "char"]:
tokenizer_path = os.path.join(
files("f5_tts").joinpath("../../data"),
f"{dataset_name}_{tokenizer}/vocab.txt",
)
with open(tokenizer_path, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
assert (
vocab_char_map[" "] == 0
), "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
elif tokenizer == "byte":
vocab_char_map = None
vocab_size = 256
elif tokenizer == "custom":
with open(dataset_name, "r", encoding="utf-8") as f:
vocab_char_map = {}
for i, char in enumerate(f):
vocab_char_map[char[:-1]] = i
vocab_size = len(vocab_char_map)
return vocab_char_map, vocab_size
# convert char to pinyin
jieba.initialize()
print("Word segmentation module jieba initialized.\n")
def convert_char_to_pinyin(text_list, polyphone=True):
final_text_list = []
custom_trans = str.maketrans(
{";": ",", "": '"', "": '"', "": "'", "": "'"}
) # add custom trans here, to address oov
def is_chinese(c):
return "\u3100" <= c <= "\u9fff" # common chinese characters
for text in text_list:
char_list = []
text = text.translate(custom_trans)
for seg in jieba.cut(text):
seg_byte_len = len(bytes(seg, "UTF-8"))
if seg_byte_len == len(seg): # if pure alphabets and symbols
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
char_list.append(" ")
char_list.extend(seg)
elif polyphone and seg_byte_len == 3 * len(
seg
): # if pure east asian characters
seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
for i, c in enumerate(seg):
if is_chinese(c):
char_list.append(" ")
char_list.append(seg_[i])
else: # if mixed characters, alphabets and symbols
for c in seg:
if ord(c) < 256:
char_list.extend(c)
elif is_chinese(c):
char_list.append(" ")
char_list.extend(
lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True)
)
else:
char_list.append(c)
final_text_list.append(char_list)
return final_text_list
# filter func for dirty data with many repetitions
def repetition_found(text, length=2, tolerance=10):
pattern_count = defaultdict(int)
for i in range(len(text) - length + 1):
pattern = text[i : i + length]
pattern_count[pattern] += 1
for pattern, count in pattern_count.items():
if count > tolerance:
return True
return False

View File

@ -0,0 +1,104 @@
from typing import Callable, Dict, List, Sequence, Union
import torch
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.collation import collate_audio
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import ifnone
class SpeechSynthesisDataset(torch.utils.data.Dataset):
"""
The PyTorch Dataset for the speech synthesis task.
Each item in this dataset is a dict of:
.. code-block::
{
'audio': (B x NumSamples) float tensor
'features': (B x NumFrames x NumFeatures) float tensor
'audio_lens': (B, ) int tensor
'features_lens': (B, ) int tensor
'text': List[str] of len B # when return_text=True
'tokens': List[List[str]] # when return_tokens=True
'speakers': List[str] of len B # when return_spk_ids=True
'cut': List of Cuts # when return_cuts=True
}
"""
def __init__(
self,
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
feature_input_strategy: BatchIO = PrecomputedFeatures(),
feature_transforms: Union[Sequence[Callable], Callable] = None,
return_text: bool = True,
return_tokens: bool = False,
return_spk_ids: bool = False,
return_cuts: bool = False,
) -> None:
super().__init__()
self.cut_transforms = ifnone(cut_transforms, [])
self.feature_input_strategy = feature_input_strategy
self.return_text = return_text
self.return_tokens = return_tokens
self.return_spk_ids = return_spk_ids
self.return_cuts = return_cuts
if feature_transforms is None:
feature_transforms = []
elif not isinstance(feature_transforms, Sequence):
feature_transforms = [feature_transforms]
assert all(
isinstance(transform, Callable) for transform in feature_transforms
), "Feature transforms must be Callable"
self.feature_transforms = feature_transforms
def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
validate_for_tts(cuts)
for transform in self.cut_transforms:
cuts = transform(cuts)
# audio, audio_lens = collate_audio(cuts)
features, features_lens = self.feature_input_strategy(cuts)
for transform in self.feature_transforms:
features = transform(features)
batch = {
# "audio": audio,
"features": features,
# "audio_lens": audio_lens,
"features_lens": features_lens,
}
if self.return_text:
# use normalized text
# text = [cut.supervisions[0].normalized_text for cut in cuts]
text = [cut.supervisions[0].text for cut in cuts]
batch["text"] = text
if self.return_tokens:
# tokens = [cut.tokens for cut in cuts]
tokens = [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts]
batch["tokens"] = tokens
if self.return_spk_ids:
batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
if self.return_cuts:
batch["cut"] = [cut for cut in cuts]
return batch
def validate_for_tts(cuts: CutSet) -> None:
validate(cuts)
for cut in cuts:
assert (
len(cut.supervisions) == 1
), "Only the Cuts with single supervision are supported."

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,306 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset,
CutConcatenate,
CutMix,
DynamicBucketingSampler,
PrecomputedFeatures,
SimpleCutSampler,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from speech_synthesis import SpeechSynthesisDataset # noqa F401
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class TtsDataModule:
"""
DataModule for tts experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="TTS data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
parser.add_argument(
"--prefix",
type=str,
default="wenetspeech4tts",
help="prefix of the manifest file",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
logging.info("About to create train dataset")
train = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=True,
pin_memory=True,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
else:
validate = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create valid dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=True,
pin_memory=True,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
if self.args.on_the_fly_feats:
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
else:
test = SpeechSynthesisDataset(
return_text=True,
return_tokens=False,
feature_input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
shuffle=False,
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=test_sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / f"{self.args.prefix}_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get validation cuts")
return load_manifest_lazy(
self.args.manifest_dir / f"{self.args.prefix}_cuts_valid.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / f"{self.args.prefix}_cuts_test.jsonl.gz"
)

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/utils.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,122 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import math
import os
import pathlib
import random
from typing import List, Optional, Tuple
import librosa
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from tqdm import tqdm
# from env import AttrDict
MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
return dynamic_range_compression_torch(magnitudes)
def spectral_de_normalize_torch(magnitudes):
return dynamic_range_decompression_torch(magnitudes)
mel_basis_cache = {}
hann_window_cache = {}
def mel_spectrogram(
y: torch.Tensor,
n_fft: int = 1024,
num_mels: int = 100,
sampling_rate: int = 24_000,
hop_size: int = 256,
win_size: int = 1024,
fmin: int = 0,
fmax: int = None,
center: bool = False,
) -> torch.Tensor:
"""
Calculate the mel spectrogram of an input signal.
This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
Args:
y (torch.Tensor): Input signal.
n_fft (int): FFT size.
num_mels (int): Number of mel bins.
sampling_rate (int): Sampling rate of the input signal.
hop_size (int): Hop size for STFT.
win_size (int): Window size for STFT.
fmin (int): Minimum frequency for mel filterbank.
fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
center (bool): Whether to pad the input to center the frames. Default is False.
Returns:
torch.Tensor: Mel spectrogram.
"""
if torch.min(y) < -1.0:
print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
if torch.max(y) > 1.0:
print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
device = y.device
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
if key not in mel_basis_cache:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
hann_window_cache[key] = torch.hann_window(win_size).to(device)
mel_basis = mel_basis_cache[key]
hann_window = hann_window_cache[key]
padding = (n_fft - hop_size) // 2
y = torch.nn.functional.pad(
y.unsqueeze(1), (padding, padding), mode="reflect"
).squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
mel_spec = torch.matmul(mel_basis, spec)
mel_spec = spectral_normalize_torch(mel_spec)
return mel_spec

View File

@ -0,0 +1,218 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the LJSpeech dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--num-jobs",
type=int,
default=1,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)
parser.add_argument(
"--src-dir",
type=Path,
default=Path("data/manifests"),
help="Path to the manifest files",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("data/fbank"),
help="Path to the tokenized files",
)
parser.add_argument(
"--dataset-parts",
type=str,
default="Basic",
help="Space separated dataset parts",
)
parser.add_argument(
"--prefix",
type=str,
default="wenetspeech4tts",
help="prefix of the manifest file",
)
parser.add_argument(
"--suffix",
type=str,
default="jsonl.gz",
help="suffix of the manifest file",
)
parser.add_argument(
"--split",
type=int,
default=100,
help="Split the cut_set into multiple parts",
)
parser.add_argument(
"--resample-to-24kHz",
default=True,
help="Resample the audio to 24kHz",
)
parser.add_argument(
"--extractor",
type=str,
choices=["bigvgan", "hifigan"],
default="bigvgan",
help="The type of extractor to use",
)
return parser
def compute_fbank(args):
src_dir = Path(args.src_dir)
output_dir = Path(args.output_dir)
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
num_jobs = min(args.num_jobs, os.cpu_count())
dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip().split(" ")
logging.info(f"num_jobs: {num_jobs}")
logging.info(f"src_dir: {src_dir}")
logging.info(f"output_dir: {output_dir}")
logging.info(f"dataset_parts: {dataset_parts}")
if args.extractor == "bigvgan":
config = MatchaFbankConfig(
n_fft=1024,
n_mels=100,
sampling_rate=24_000,
hop_length=256,
win_length=1024,
f_min=0,
f_max=None,
)
elif args.extractor == "hifigan":
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=22050,
hop_length=256,
win_length=1024,
f_min=0,
f_max=8000,
)
else:
raise NotImplementedError(f"Extractor {args.extractor} is not implemented")
extractor = MatchaFbank(config)
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=args.src_dir,
prefix=args.prefix,
suffix=args.suffix,
types=["recordings", "supervisions", "cuts"],
)
with get_executor() as ex:
for partition, m in manifests.items():
logging.info(
f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}"
)
try:
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
except Exception:
cut_set = m["cuts"]
if args.split > 1:
cut_sets = cut_set.split(args.split)
else:
cut_sets = [cut_set]
for idx, part in enumerate(cut_sets):
if args.split > 1:
storage_path = f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}_{idx}"
else:
storage_path = (
f"{args.output_dir}/{args.prefix}_{args.extractor}_{partition}"
)
if args.resample_to_24kHz:
part = part.resample(24000)
with torch.no_grad():
part = part.compute_and_store_features(
extractor=extractor,
storage_path=storage_path,
num_jobs=num_jobs if ex is None else 64,
executor=ex,
storage_type=LilcomChunkyWriter,
)
if args.split > 1:
cuts_filename = (
f"{args.prefix}_cuts_{partition}.{idx}.{args.suffix}"
)
else:
cuts_filename = f"{args.prefix}_cuts_{partition}.{args.suffix}"
part.to_file(f"{args.output_dir}/{cuts_filename}")
logging.info(f"Saved {cuts_filename}")
if __name__ == "__main__":
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_parser().parse_args()
compute_fbank(args)

View File

@ -0,0 +1,26 @@
wav_dir=$1
wav_files=$(ls $wav_dir/*.wav)
# if wav_files is empty, then exit
if [ -z "$wav_files" ]; then
exit 1
fi
label_file=$2
model_path=local/sherpa-onnx-paraformer-zh-2023-09-14
if [ ! -d $model_path ]; then
pip install sherpa-onnx
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C local
fi
python3 local/offline-decode-files.py \
--tokens=$model_path/tokens.txt \
--paraformer=$model_path/model.int8.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
--sample-rate=24000 \
--log-dir $wav_dir \
--feature-dim=80 \
--label $label_file \
$wav_files

View File

@ -0,0 +1 @@
../../../ljspeech/TTS/matcha/fbank.py

View File

@ -0,0 +1,495 @@
#!/usr/bin/env python3
#
# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation
"""
This file demonstrates how to use sherpa-onnx Python API to transcribe
file(s) with a non-streaming model.
(1) For paraformer
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--paraformer=/path/to/paraformer.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
--sample-rate=16000 \
--feature-dim=80 \
/path/to/0.wav \
/path/to/1.wav
(2) For transducer models from icefall
./python-api-examples/offline-decode-files.py \
--tokens=/path/to/tokens.txt \
--encoder=/path/to/encoder.onnx \
--decoder=/path/to/decoder.onnx \
--joiner=/path/to/joiner.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
--sample-rate=16000 \
--feature-dim=80 \
/path/to/0.wav \
/path/to/1.wav
(3) For CTC models from NeMo
python3 ./python-api-examples/offline-decode-files.py \
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
--num-threads=2 \
--decoding-method=greedy_search \
--debug=false \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
(4) For Whisper models
python3 ./python-api-examples/offline-decode-files.py \
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
--whisper-task=transcribe \
--num-threads=1 \
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
(5) For CTC models from WeNet
python3 ./python-api-examples/offline-decode-files.py \
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
(6) For tdnn models of the yesno recipe from icefall
python3 ./python-api-examples/offline-decode-files.py \
--sample-rate=8000 \
--feature-dim=23 \
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download non-streaming pre-trained models
used in this file.
"""
import argparse
import time
import wave
from pathlib import Path
from typing import List, Tuple
import numpy as np
import sherpa_onnx
import soundfile as sf
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--tokens",
type=str,
help="Path to tokens.txt",
)
parser.add_argument(
"--hotwords-file",
type=str,
default="",
help="""
The file containing hotwords, one words/phrases per line, like
HELLO WORLD
你好世界
""",
)
parser.add_argument(
"--hotwords-score",
type=float,
default=1.5,
help="""
The hotword score of each token for biasing word/phrase. Used only if
--hotwords-file is given.
""",
)
parser.add_argument(
"--modeling-unit",
type=str,
default="",
help="""
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
Used only when hotwords-file is given.
""",
)
parser.add_argument(
"--bpe-vocab",
type=str,
default="",
help="""
The path to the bpe vocabulary, the bpe vocabulary is generated by
sentencepiece, you can also export the bpe vocabulary through a bpe model
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
and modeling-unit is bpe or cjkchar+bpe.
""",
)
parser.add_argument(
"--encoder",
default="",
type=str,
help="Path to the encoder model",
)
parser.add_argument(
"--decoder",
default="",
type=str,
help="Path to the decoder model",
)
parser.add_argument(
"--joiner",
default="",
type=str,
help="Path to the joiner model",
)
parser.add_argument(
"--paraformer",
default="",
type=str,
help="Path to the model.onnx from Paraformer",
)
parser.add_argument(
"--nemo-ctc",
default="",
type=str,
help="Path to the model.onnx from NeMo CTC",
)
parser.add_argument(
"--wenet-ctc",
default="",
type=str,
help="Path to the model.onnx from WeNet CTC",
)
parser.add_argument(
"--tdnn-model",
default="",
type=str,
help="Path to the model.onnx for the tdnn model of the yesno recipe",
)
parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)
parser.add_argument(
"--whisper-encoder",
default="",
type=str,
help="Path to whisper encoder model",
)
parser.add_argument(
"--whisper-decoder",
default="",
type=str,
help="Path to whisper decoder model",
)
parser.add_argument(
"--whisper-language",
default="",
type=str,
help="""It specifies the spoken language in the input audio file.
Example values: en, fr, de, zh, jp.
Available languages for multilingual models can be found at
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
If not specified, we infer the language from the input audio file.
""",
)
parser.add_argument(
"--whisper-task",
default="transcribe",
choices=["transcribe", "translate"],
type=str,
help="""For multilingual models, if you specify translate, the output
will be in English.
""",
)
parser.add_argument(
"--whisper-tail-paddings",
default=-1,
type=int,
help="""Number of tail padding frames.
We have removed the 30-second constraint from whisper, so you need to
choose the amount of tail padding frames by yourself.
Use -1 to use a default value for tail padding.
""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
parser.add_argument(
"--debug",
type=bool,
default=False,
help="True to show debug messages",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="""Sample rate of the feature extractor. Must match the one
expected by the model. Note: The input sound files can have a
different sample rate from this argument.""",
)
parser.add_argument(
"--feature-dim",
type=int,
default=80,
help="Feature dimension. Must match the one expected by the model",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to decode. Each file must be of WAVE"
"format with a single channel, and each sample has 16-bit, "
"i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
)
parser.add_argument(
"--name",
type=str,
default="",
help="The directory containing the input sound files to decode",
)
parser.add_argument(
"--log-dir",
type=str,
default="",
help="The directory containing the input sound files to decode",
)
parser.add_argument(
"--label",
type=str,
default=None,
help="wav_base_name label",
)
return parser.parse_args()
def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and can be of type
32-bit floating point PCM. Its sample rate does not need to be 24kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples,
which are normalized to the range [-1, 1].
- Sample rate of the wave file.
"""
samples, sample_rate = sf.read(wave_filename, dtype="float32")
assert (
samples.ndim == 1
), f"Expected single channel, but got {samples.ndim} channels."
samples_float32 = samples.astype(np.float32)
return samples_float32, sample_rate
def normalize_text_alimeeting(text: str) -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
def main():
args = get_args()
assert_file_exists(args.tokens)
assert args.num_threads > 0, args.num_threads
assert len(args.nemo_ctc) == 0, args.nemo_ctc
assert len(args.wenet_ctc) == 0, args.wenet_ctc
assert len(args.whisper_encoder) == 0, args.whisper_encoder
assert len(args.whisper_decoder) == 0, args.whisper_decoder
assert len(args.tdnn_model) == 0, args.tdnn_model
assert_file_exists(args.paraformer)
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
paraformer=args.paraformer,
tokens=args.tokens,
num_threads=args.num_threads,
sample_rate=args.sample_rate,
feature_dim=args.feature_dim,
decoding_method=args.decoding_method,
debug=args.debug,
)
print("Started!")
start_time = time.time()
streams, results = [], []
total_duration = 0
for i, wave_filename in enumerate(args.sound_files):
assert_file_exists(wave_filename)
samples, sample_rate = read_wave(wave_filename)
duration = len(samples) / sample_rate
total_duration += duration
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
streams.append(s)
if i % 10 == 0:
recognizer.decode_streams(streams)
results += [s.result.text for s in streams]
streams = []
print(f"Processed {i} files")
# process the last batch
if streams:
recognizer.decode_streams(streams)
results += [s.result.text for s in streams]
end_time = time.time()
print("Done!")
results_dict = {}
for wave_filename, result in zip(args.sound_files, results):
print(f"{wave_filename}\n{result}")
print("-" * 10)
wave_basename = Path(wave_filename).stem
results_dict[wave_basename] = result
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
print(f"num_threads: {args.num_threads}")
print(f"decoding_method: {args.decoding_method}")
print(f"Wave duration: {total_duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
if args.label:
from icefall.utils import store_transcripts, write_error_stats
labels_dict = {}
with open(args.label, "r") as f:
for line in f:
# fields = line.strip().split(" ")
# fields = [item for item in fields if item]
# assert len(fields) == 4
# prompt_text, prompt_audio, text, audio_path = fields
fields = line.strip().split("|")
fields = [item for item in fields if item]
assert len(fields) == 4
audio_path, prompt_text, prompt_audio, text = fields
labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
final_results = []
for key, value in results_dict.items():
final_results.append((key, labels_dict[key], value))
store_transcripts(
filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
)
with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
write_error_stats(f, "test-set", final_results, enable_log=True)
with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
print(f.readline()) # WER
print(f.readline()) # Detailed errors
if __name__ == "__main__":
main()

View File

@ -98,3 +98,44 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
fi
subset="Basic"
prefix="wenetspeech4tts"
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate fbank (used by ./f5-tts)"
mkdir -p data/fbank
if [ ! -e data/fbank/.${prefix}.done ]; then
./local/compute_mel_feat.py --dataset-parts $subset --split 100
touch data/fbank/.${prefix}.done
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)"
if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
echo "Combining ${prefix} cuts"
pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")
lhotse combine $pieces data/fbank/${prefix}_cuts_${subset}.jsonl.gz
fi
if [ ! -e data/fbank/.${prefix}_split.done ]; then
echo "Splitting ${prefix} cuts into train, valid and test sets"
lhotse subset --last 800 \
data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
data/fbank/${prefix}_cuts_validtest.jsonl.gz
lhotse subset --first 400 \
data/fbank/${prefix}_cuts_validtest.jsonl.gz \
data/fbank/${prefix}_cuts_valid.jsonl.gz
lhotse subset --last 400 \
data/fbank/${prefix}_cuts_validtest.jsonl.gz \
data/fbank/${prefix}_cuts_test.jsonl.gz
rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
lhotse subset --first $n \
data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
data/fbank/${prefix}_cuts_train.jsonl.gz
touch data/fbank/.${prefix}_split.done
fi
fi

View File

@ -118,13 +118,6 @@ def get_args():
help="The temperature of AR Decoder top_k sampling.",
)
parser.add_argument(
"--continual",
type=str2bool,
default=False,
help="Do continual task.",
)
parser.add_argument(
"--repetition-aware-sampling",
type=str2bool,
@ -262,29 +255,21 @@ def main():
)
# synthesis
if args.continual:
assert text == ""
encoded_frames = model.continual(
text_tokens.to(device),
text_tokens_lens.to(device),
audio_prompts,
)
else:
enroll_x_lens = None
if text_prompts:
_, enroll_x_lens = text_collater(
[tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
)
encoded_frames = model.inference(
text_tokens.to(device),
text_tokens_lens.to(device),
audio_prompts,
enroll_x_lens=enroll_x_lens,
top_k=args.top_k,
temperature=args.temperature,
top_p=args.top_p,
ras=args.repetition_aware_sampling,
enroll_x_lens = None
if text_prompts:
_, enroll_x_lens = text_collater(
[tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
)
encoded_frames = model.inference(
text_tokens.to(device),
text_tokens_lens.to(device),
audio_prompts,
enroll_x_lens=enroll_x_lens,
top_k=args.top_k,
temperature=args.temperature,
top_p=args.top_p,
ras=args.repetition_aware_sampling,
)
if audio_prompts != []:
samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])

View File

@ -1564,103 +1564,6 @@ class VALLE(nn.Module):
assert len(codes) == self.num_quantizers
return torch.stack(codes, dim=-1)
def continual(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
A 2-D tensor of shape (1, S).
x_lens:
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
before padding.
y:
A 3-D tensor of shape (1, T, 8).
Returns:
Return the predicted audio code matrix.
"""
assert x.ndim == 2, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.ndim == 3, y.shape
assert y.shape[0] == 1, y.shape
assert torch.all(x_lens > 0)
assert self.num_quantizers == 8
# NOTE: x has been padded in TextTokenCollater
text = x
x = self.ar_text_embedding(text)
x = self.ar_text_prenet(x)
x = self.ar_text_position(x)
text_len = x_lens.max()
prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
# AR Decoder
prompts = y[:, :prefix_len]
codes = [y[:, prefix_len:, 0]]
# Non-AR Decoders
x = self.nar_text_embedding(text)
x = self.nar_text_prenet(x)
x = self.nar_text_position(x)
y_emb = self.nar_audio_embeddings[0](y[..., 0])
if self.prefix_mode == 0:
for i, (predict_layer, embedding_layer) in enumerate(
zip(
self.nar_predict_layers,
self.nar_audio_embeddings[1:],
)
):
y_pos = self.nar_audio_position(y_emb)
y_pos = self.nar_audio_prenet(y_pos)
xy_pos = torch.concat([x, y_pos], dim=1)
xy_dec, _ = self.nar_decoder(
(xy_pos, self.nar_stage_embeddings[i].weight)
)
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
samples = torch.argmax(logits, dim=-1)
codes.append(samples)
if i < 6:
y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
y_emb[:, prefix_len:] += embedding_layer(samples)
else:
for j in range(1, 8):
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
for i, (predict_layer, embedding_layer) in enumerate(
zip(
self.nar_predict_layers,
self.nar_audio_embeddings[1:],
)
):
y_pos = self.nar_audio_prenet(y_emb)
y_pos = self.nar_audio_position(y_pos)
xy_pos = torch.concat([x, y_pos], dim=1)
xy_dec, _ = self.nar_decoder(
(xy_pos, self.nar_stage_embeddings[i].weight)
)
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
samples = torch.argmax(logits, dim=-1)
codes.append(samples)
if i < 6:
y_emb[:, prefix_len:] += embedding_layer(samples)
assert len(codes) == 8
return torch.stack(codes, dim=-1)
def visualize(
self,
predicts: Tuple[torch.Tensor],

View File

@ -39,24 +39,34 @@ def register_inf_check_hooks(model: nn.Module) -> None:
# default param _name is a way to capture the current value of the variable "name".
def forward_hook(_module, _input, _output, _name=name):
if isinstance(_output, Tensor):
if not torch.isfinite(_output.to(torch.float32).sum()):
logging.warning(f"The sum of {_name}.output is not finite")
try:
if not torch.isfinite(_output.to(torch.float32).sum()):
logging.warning(f"The sum of {_name}.output is not finite")
except RuntimeError: # e.g. CUDA out of memory
pass
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
if isinstance(o, tuple):
o = o[0]
if not isinstance(o, Tensor):
continue
if not torch.isfinite(o.to(torch.float32).sum()):
logging.warning(f"The sum of {_name}.output[{i}] is not finite")
try:
if not torch.isfinite(o.to(torch.float32).sum()):
logging.warning(
f"The sum of {_name}.output[{i}] is not finite"
)
except RuntimeError: # e.g. CUDA out of memory
pass
# default param _name is a way to capture the current value of the variable "name".
def backward_hook(_module, _input, _output, _name=name):
if isinstance(_output, Tensor):
if not torch.isfinite(_output.to(torch.float32).sum()):
logging.warning(
f"The sum of {_name}.grad is not finite" # ": {_output}"
)
try:
if not torch.isfinite(_output.to(torch.float32).sum()):
logging.warning(f"The sum of {_name}.grad is not finite")
except RuntimeError: # e.g. CUDA out of memory
pass
elif isinstance(_output, tuple):
for i, o in enumerate(_output):
if isinstance(o, tuple):