Code cleanup

root 2025-01-20 09:07:19 +00:00
parent 29de94ee2a
commit 455366418c
9 changed files with 81 additions and 194 deletions

View File

@ -11,7 +11,7 @@ repos:
rev: 5.0.4
hooks:
- id: flake8
args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503, F722, F821"]
args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
#exclude:
# What are we ignoring here?
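Dropping F722 (syntax error in forward annotation) and F821 (undefined name) from the extend-ignore list turns those pyflakes checks back on. A purely hypothetical snippet that flake8 would now reject (function and names are illustrative, not from this repo):

def decode(mel):
    return vocodr(mel)  # F821: undefined name 'vocodr' -- previously silenced by the ignore list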

View File

@ -1,3 +1,19 @@
#!/usr/bin/env python3
# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py
"""
Usage:
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx
# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst
python3 f5-tts/generate_averaged_model.py \
--epoch 56 \
--avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--exp-dir exp/f5_small
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
bash local/compute_wer.sh $output_dir $manifest
"""
import argparse
import logging
import math
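generate_averaged_model.py with --epoch 56 --avg 14 averages the weights of the last 14 epoch checkpoints before inference. Conceptually this boils down to plain parameter averaging; a minimal sketch, assuming icefall-style exp/f5_small/epoch-N.pt files with a "model" key (the paths, key name, and output filename are assumptions, and the real script handles averaged/EMA state more carefully):

import torch

ckpt_paths = [f"exp/f5_small/epoch-{i}.pt" for i in range(43, 57)]  # epochs 43..56 -> average of 14
states = [torch.load(p, map_location="cpu")["model"] for p in ckpt_paths]
averaged = {k: sum(s[k].to(torch.float64) for s in states) / len(states) for k in states[0]}
torch.save({"model": averaged}, "exp/f5_small/epoch-56-avg-14.pt")  # assumed output name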
@ -62,7 +78,7 @@ def get_parser():
parser.add_argument(
"--manifest-file",
type=str,
default="/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst",
default="/path/seed_tts_eval/seedtts_testset/zh/meta.lst",
help="The manifest file in seed_tts_eval format",
)
@ -180,7 +196,6 @@ def get_inference_prompt(
batch_accum[bucket_i] += total_mel_len
if batch_accum[bucket_i] >= infer_batch_size:
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
prompts_all.append(
(
utts[bucket_i],
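The hunk above drops a leftover debug print from the batching loop: batch_accum tracks the accumulated mel length per duration bucket, and a batch is flushed once it reaches infer_batch_size, so the budget is measured in frames rather than in utterances. The same idea in isolation, as a sketch:

def batch_by_frame_budget(items, mel_lens, frame_budget):
    """Group items into batches whose accumulated mel length reaches frame_budget (sketch)."""
    batches, current, accum = [], [], 0
    for item, n_frames in zip(items, mel_lens):
        current.append(item)
        accum += n_frames
        if accum >= frame_budget:
            batches.append(current)
            current, accum = [], 0
    if current:  # keep the final, partially filled batch
        batches.append(current)
    return batches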
@ -282,7 +297,7 @@ def main():
model = get_model(args).eval().to(device)
checkpoint = torch.load(args.model_path, map_location="cpu")
if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint:
if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
model = load_F5_TTS_pretrained_checkpoint(model, args.model_path)
else:
_ = load_checkpoint(

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/optim.py

View File

@ -20,8 +20,17 @@
# limitations under the License.
"""
Usage:
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
world_size=8
exp_dir=exp/ft-tts
exp_dir=exp/f5-tts-small
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
--base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
--exp-dir ${exp_dir} --world-size ${world_size}
"""
import argparse
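For scale: --max-duration 700 is a per-GPU budget in seconds of audio per batch, so with world_size 8 one optimizer step covers roughly 700 × 8 = 5600 s of audio (more if gradient accumulation is enabled). A trivial check of that arithmetic:

max_duration_s = 700   # lhotse per-GPU duration budget, in seconds
world_size = 8
print(max_duration_s * world_size)  # 5600 seconds of audio per step across the 8 GPUs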
@ -45,13 +54,10 @@ from lhotse.utils import fix_random_seed
from model.cfm import CFM
from model.dit import DiT
from model.utils import convert_char_to_pinyin
from optim import Eden, ScaledAdam
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch import Tensor
# from torch.cuda.amp import GradScaler
from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import TtsDataModule
from utils import MetricsTracker
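The GradScaler import now comes from torch.amp instead of the deprecated torch.cuda.amp; the newer class takes the device type as an explicit argument. For example:

from torch.amp import GradScaler

scaler = GradScaler("cuda", enabled=True)  # equivalent of the old torch.cuda.amp.GradScaler(enabled=True)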
@ -87,12 +93,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default=1024,
help="Embedding dimension in the decoder model.",
)
parser.add_argument(
"--nhead",
type=int,
default=16,
help="Number of attention heads in the Decoder layers.",
)
parser.add_argument(
"--num-decoder-layers",
type=int,
@ -156,7 +164,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=Path,
default="exp/valle_dev",
default="exp/f5",
help="""The experiment dir.
It specifies the directory where all training-related
files, e.g., checkpoints, logs, etc., are saved
@ -169,7 +177,7 @@ def get_parser():
default="f5-tts/vocab.txt",
help="Path to the unique text tokens file",
)
# /home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
parser.add_argument(
"--pretrained-model-path",
type=str,
@ -180,15 +188,9 @@ def get_parser():
parser.add_argument(
"--optimizer-name",
type=str,
default="ScaledAdam",
default="AdamW",
help="The optimizer.",
)
parser.add_argument(
"--scheduler-name",
type=str,
default="Eden",
help="The scheduler.",
)
parser.add_argument(
"--base-lr", type=float, default=0.05, help="The base learning rate."
)
@ -203,7 +205,7 @@ def get_parser():
parser.add_argument(
"--decay-steps",
type=int,
default=None,
default=1000000,
help="""Number of steps that affects how rapidly the learning rate
decreases. We suggest not to change this.""",
)
@ -286,6 +288,7 @@ def get_parser():
default=0.0,
help="Keep only utterances with duration > this.",
)
parser.add_argument(
"--filter-max-duration",
type=float,
@ -293,13 +296,6 @@ def get_parser():
help="Keep only utterances with duration < this.",
)
parser.add_argument(
"--visualize",
type=str2bool,
default=False,
help="visualize model results in eval step.",
)
parser.add_argument(
"--oom-check",
type=str2bool,
@ -383,6 +379,7 @@ def get_tokenizer(vocab_file_path: str):
def get_model(params):
vocab_char_map, vocab_size = get_tokenizer(params.tokens)
# bigvgan 100 dim features
n_mel_channels = 100
n_fft = 1024
sampling_rate = 24_000
@ -421,7 +418,6 @@ def get_model(params):
def load_F5_TTS_pretrained_checkpoint(
model, ckpt_path, device: str = "cpu", dtype=torch.float32
):
# model = model.to(dtype)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
if "ema_model_state_dict" in checkpoint:
checkpoint["model_state_dict"] = {
@ -641,14 +637,6 @@ def compute_validation_loss(
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
# if params.visualize:
# output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
# output_dir.mkdir(parents=True, exist_ok=True)
# if isinstance(model, DDP):
# model.module.visualize(predicts, batch, output_dir=output_dir)
# else:
# model.visualize(predicts, batch, output_dir=output_dir)
return tot_loss
@ -744,7 +732,7 @@ def train_one_epoch(
scaler.scale(loss).backward()
if params.batch_idx_train >= params.accumulate_grad_steps:
if params.batch_idx_train % params.accumulate_grad_steps == 0:
if params.optimizer_name not in ["ScaledAdam", "Eve"]:
# Unscales the gradients of optimizer's assigned params in-place
scaler.unscale_(optimizer)
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
@ -757,9 +745,6 @@ def train_one_epoch(
# optimizer.step()
for k in range(params.accumulate_grad_steps):
if isinstance(scheduler, Eden):
scheduler.step_batch(params.batch_idx_train)
else:
scheduler.step()
set_batch_count(model, params.batch_idx_train)
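With ScaledAdam/Eve and the Eden scheduler removed, the optimizer-specific branches are gone and every update follows the standard AMP gradient-accumulation recipe: unscale, clip, step, update. A self-contained sketch of that pattern (model, loader, optimizer, scheduler, the clipping threshold, and accumulate_grad_steps are placeholders, not values from this recipe):

import torch
from torch.amp import GradScaler, autocast

def train_steps(model, loader, optimizer, scheduler, accumulate_grad_steps=4):
    """Standard AMP + gradient-accumulation loop (sketch; clipping threshold illustrative)."""
    scaler = GradScaler("cuda")
    for step, batch in enumerate(loader, start=1):
        with autocast("cuda", dtype=torch.bfloat16):  # matches --dtype "bfloat16"
            loss = model(batch)                       # assumes the model returns a scalar loss
        scaler.scale(loss).backward()
        if step % accumulate_grad_steps == 0:
            scaler.unscale_(optimizer)                # bring gradients back to their true scale
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)                    # skipped internally if inf/nan gradients are found
            scaler.update()
            optimizer.zero_grad()
            for _ in range(accumulate_grad_steps):    # the recipe advances the per-batch scheduler once per micro-batch
                scheduler.step()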
@ -940,16 +925,18 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
tokenizer = get_tokenizer(params.tokens)
print("the class type of tokenizer is: ", type(tokenizer))
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint:
model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
model = load_F5_TTS_pretrained_checkpoint(
model, params.pretrained_model_path
)
else:
_ = load_checkpoint(
params.pretrained_model_path,
@ -984,13 +971,6 @@ def run(rank, world_size, args):
model_parameters = model.parameters()
if params.optimizer_name == "ScaledAdam":
optimizer = ScaledAdam(
model_parameters,
lr=params.base_lr,
clipping_scale=2.0,
)
elif params.optimizer_name == "AdamW":
optimizer = torch.optim.AdamW(
model_parameters,
lr=params.base_lr,
@ -998,13 +978,17 @@ def run(rank, world_size, args):
weight_decay=1e-2,
eps=1e-8,
)
else:
raise NotImplementedError()
warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
warmup_scheduler = LinearLR(
optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps
)
decay_scheduler = LinearLR(
optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps
)
scheduler = SequentialLR(
optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
optimizer,
schedulers=[warmup_scheduler, decay_scheduler],
milestones=[params.warmup_steps],
)
optimizer.zero_grad()
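The resulting schedule is a linear warm-up from near zero to --base-lr over --warmup-steps, followed by a linear decay back towards zero over --decay-steps (now defaulting to 1000000). A small self-contained way to eyeball it, with a dummy parameter and shortened step counts:

import torch
from torch.optim.lr_scheduler import LinearLR, SequentialLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=7.5e-5)
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=100)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=1000)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[100])
for step in range(1, 1101):
    optimizer.step()
    scheduler.step()
    if step in (1, 50, 100, 600, 1100):
        print(step, scheduler.get_last_lr())  # ramps up to 7.5e-5 by step 100, then decays linearly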
@ -1062,8 +1046,6 @@ def run(rank, world_size, args):
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
if isinstance(scheduler, Eden):
scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
@ -1140,7 +1122,6 @@ def scan_pessimistic_batches_for_oom(
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
print(23333)
dtype = torch.float32
if params.dtype in ["bfloat16", "bf16"]:
dtype = torch.bfloat16

View File

@ -24,8 +24,6 @@ from pathlib import Path
from typing import Any, Dict, Optional
import torch
# from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset,
CutConcatenate,
@ -185,22 +183,6 @@ class TtsDataModule:
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
# sampling_rate = 22050
# config = MatchaFbankConfig(
# n_fft=1024,
# n_mels=80,
# sampling_rate=sampling_rate,
# hop_length=256,
# win_length=1024,
# f_min=0,
# f_max=8000,
# )
# train = SpeechSynthesisDataset(
# return_text=True,
# return_tokens=False,
# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
# return_cuts=self.args.return_cuts,
# )
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
@ -249,22 +231,6 @@ class TtsDataModule:
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
# sampling_rate = 22050
# config = MatchaFbankConfig(
# n_fft=1024,
# n_mels=80,
# sampling_rate=sampling_rate,
# hop_length=256,
# win_length=1024,
# f_min=0,
# f_max=8000,
# )
# validate = SpeechSynthesisDataset(
# return_text=True,
# return_tokens=False,
# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
# return_cuts=self.args.return_cuts,
# )
else:
validate = SpeechSynthesisDataset(
return_text=True,
@ -296,22 +262,6 @@ class TtsDataModule:
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
# sampling_rate = 22050
# config = MatchaFbankConfig(
# n_fft=1024,
# n_mels=80,
# sampling_rate=sampling_rate,
# hop_length=256,
# win_length=1024,
# f_min=0,
# f_max=8000,
# )
# test = SpeechSynthesisDataset(
# return_text=True,
# return_tokens=False,
# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
# return_cuts=self.args.return_cuts,
# )
else:
test = SpeechSynthesisDataset(
return_text=True,

View File

@ -1,15 +0,0 @@
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
#bigvganinference
model_path=/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
manifest=/home/yuekaiz/HF/valle_wenetspeech4tts_demo/wenetspeech4tts.txt
manifest=/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst
# get wenetspeech4tts
manifest_base_stem=$(basename $manifest)
mainfest_base_stem=${manifest_base_stem%.*}
output_dir=./results/f5-tts-pretrained/$mainfest_base_stem
pip install sherpa-onnx bigvganinference lhotse kaldialign sentencepiece
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir || exit 1
bash local/compute_wer.sh $output_dir $manifest

View File

@ -1,6 +1,5 @@
wav_dir=$1
wav_files=$(ls $wav_dir/*.wav)
# wav_files=$(echo $wav_files | cut -d " " -f 1)
# if wav_files is empty, then exit
if [ -z "$wav_files" ]; then
exit 1

View File

@ -1,5 +1,4 @@
#!/usr/bin/env bash
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
set -eou pipefail
@ -7,8 +6,8 @@ set -eou pipefail
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
stage=7
stop_stage=7
stage=1
stop_stage=4
dl_dir=$PWD/download
@ -101,23 +100,10 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: build monotonic_align lib (used by matcha recipes)"
for recipe in matcha; do
if [ ! -d $recipe/monotonic_align/build ]; then
cd $recipe/monotonic_align
python3 setup.py build_ext --inplace
cd ../../
else
log "monotonic_align lib for $recipe already built"
fi
done
fi
subset="Basic"
prefix="wenetspeech4tts"
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Generate fbank (used by ./matcha)"
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Generate fbank (used by ./f5-tts)"
mkdir -p data/fbank
if [ ! -e data/fbank/.${prefix}.done ]; then
./local/compute_mel_feat.py --dataset-parts $subset --split 100
@ -125,8 +111,8 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./matcha)"
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)"
if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
echo "Combining ${prefix} cuts"
pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")
@ -135,17 +121,17 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
if [ ! -e data/fbank/.${prefix}_split.done ]; then
echo "Splitting ${prefix} cuts into train, valid and test sets"
# lhotse subset --last 800 \
# data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
# data/fbank/${prefix}_cuts_validtest.jsonl.gz
# lhotse subset --first 400 \
# data/fbank/${prefix}_cuts_validtest.jsonl.gz \
# data/fbank/${prefix}_cuts_valid.jsonl.gz
# lhotse subset --last 400 \
# data/fbank/${prefix}_cuts_validtest.jsonl.gz \
# data/fbank/${prefix}_cuts_test.jsonl.gz
lhotse subset --last 800 \
data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
data/fbank/${prefix}_cuts_validtest.jsonl.gz
lhotse subset --first 400 \
data/fbank/${prefix}_cuts_validtest.jsonl.gz \
data/fbank/${prefix}_cuts_valid.jsonl.gz
lhotse subset --last 400 \
data/fbank/${prefix}_cuts_validtest.jsonl.gz \
data/fbank/${prefix}_cuts_test.jsonl.gz
# rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
lhotse subset --first $n \
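After this change the split runs unconditionally: the last 800 cuts form a validtest pool that is halved into 400 validation and 400 test cuts, and the remaining n cuts go to training. A quick sanity check of the resulting manifests (pure Python, assuming the stage has finished; the filenames follow the prefix/subset variables above):

import gzip

def n_cuts(path: str) -> int:
    """Count cuts in a lhotse .jsonl.gz manifest (one JSON object per line)."""
    with gzip.open(path, "rt") as f:
        return sum(1 for _ in f)

print(n_cuts("data/fbank/wenetspeech4tts_cuts_valid.jsonl.gz"))  # expected: 400
print(n_cuts("data/fbank/wenetspeech4tts_cuts_test.jsonl.gz"))   # expected: 400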

View File

@ -1,28 +0,0 @@
export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
install_flag=false
if [ "$install_flag" = true ]; then
echo "Installing packages..."
pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# lhotse tensorboard kaldialign
pip install -r requirements.txt
pip install phonemizer pypinyin sentencepiece kaldialign matplotlib h5py
apt-get update && apt-get -y install festival espeak-ng mbrola
else
echo "Skipping installation."
fi
world_size=8
#world_size=1
exp_dir=exp/f5
# pip install -r f5-tts/requirements.txt
python3 f5-tts/train.py --max-duration 300 --filter-min-duration 0.5 --filter-max-duration 20 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 8000 \
--base-lr 1e-4 --warmup-steps 5000 --average-period 200 \
--num-epochs 10 --start-epoch 1 --start-batch 20000 \
--exp-dir ${exp_dir} --world-size ${world_size}