From 455366418c7be26f9a5d00613a56e434d0a01caa Mon Sep 17 00:00:00 2001 From: root Date: Mon, 20 Jan 2025 09:07:19 +0000 Subject: [PATCH] code clean --- .pre-commit-config.yaml | 2 +- egs/wenetspeech4tts/TTS/f5-tts/infer.py | 21 +++- egs/wenetspeech4tts/TTS/f5-tts/optim.py | 1 - egs/wenetspeech4tts/TTS/f5-tts/train.py | 111 ++++++++---------- .../TTS/f5-tts/tts_datamodule.py | 50 -------- egs/wenetspeech4tts/TTS/infer_f5.sh | 15 --- egs/wenetspeech4tts/TTS/local/compute_wer.sh | 1 - egs/wenetspeech4tts/TTS/prepare.sh | 46 +++----- egs/wenetspeech4tts/TTS/train_f5.sh | 28 ----- 9 files changed, 81 insertions(+), 194 deletions(-) delete mode 120000 egs/wenetspeech4tts/TTS/f5-tts/optim.py delete mode 100644 egs/wenetspeech4tts/TTS/infer_f5.sh delete mode 100644 egs/wenetspeech4tts/TTS/train_f5.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 07dd89cda..ed7effd6a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: rev: 5.0.4 hooks: - id: flake8 - args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503, F722, F821"] + args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"] #exclude: # What are we ignoring here? diff --git a/egs/wenetspeech4tts/TTS/f5-tts/infer.py b/egs/wenetspeech4tts/TTS/f5-tts/infer.py index 8d38af1ca..02e5f0f4d 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/infer.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/infer.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python3 +# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py +""" +Usage: +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx +# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x +manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst +python3 f5-tts/generate_averaged_model.py \ + --epoch 56 \ + --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \ + --exp-dir exp/f5_small +accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18 +bash local/compute_wer.sh $output_dir $manifest +""" import argparse import logging import math @@ -62,7 +78,7 @@ def get_parser(): parser.add_argument( "--manifest-file", type=str, - default="/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst", + default="/path/seed_tts_eval/seedtts_testset/zh/meta.lst", help="The manifest file in seed_tts_eval format", ) @@ -180,7 +196,6 @@ def get_inference_prompt( batch_accum[bucket_i] += total_mel_len if batch_accum[bucket_i] >= infer_batch_size: - # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}") prompts_all.append( ( utts[bucket_i], @@ -282,7 +297,7 @@ def main(): model = get_model(args).eval().to(device) checkpoint = torch.load(args.model_path, map_location="cpu") - if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint: + if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: model = load_F5_TTS_pretrained_checkpoint(model, args.model_path) else: _ = load_checkpoint( diff --git a/egs/wenetspeech4tts/TTS/f5-tts/optim.py b/egs/wenetspeech4tts/TTS/f5-tts/optim.py deleted file mode 120000 index 5eaa3cffd..000000000 --- 
a/egs/wenetspeech4tts/TTS/f5-tts/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py index 6e400f10c..c1153360c 100755 --- a/egs/wenetspeech4tts/TTS/f5-tts/train.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py @@ -20,8 +20,17 @@ # limitations under the License. """ Usage: +# docker: ghcr.io/swivid/f5-tts:main +# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html +# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece +# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x world_size=8 -exp_dir=exp/ft-tts +exp_dir=exp/f5-tts-small +python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \ + --base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \ + --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \ + --exp-dir ${exp_dir} --world-size ${world_size} """ import argparse @@ -45,13 +54,10 @@ from lhotse.utils import fix_random_seed from model.cfm import CFM from model.dit import DiT from model.utils import convert_char_to_pinyin -from optim import Eden, ScaledAdam -from torch.optim.lr_scheduler import LinearLR, SequentialLR from torch import Tensor - -# from torch.cuda.amp import GradScaler from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim.lr_scheduler import LinearLR, SequentialLR from torch.utils.tensorboard import SummaryWriter from tts_datamodule import TtsDataModule from utils import MetricsTracker @@ -87,12 +93,14 @@ def add_model_arguments(parser: argparse.ArgumentParser): default=1024, help="Embedding dimension in the decoder model.", ) + parser.add_argument( "--nhead", type=int, default=16, help="Number of attention heads in the Decoder layers.", ) + parser.add_argument( "--num-decoder-layers", type=int, @@ -156,7 +164,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=Path, - default="exp/valle_dev", + default="exp/f5", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -169,7 +177,7 @@ def get_parser(): default="f5-tts/vocab.txt", help="Path to the unique text tokens file", ) - # /home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt + parser.add_argument( "--pretrained-model-path", type=str, @@ -180,15 +188,9 @@ def get_parser(): parser.add_argument( "--optimizer-name", type=str, - default="ScaledAdam", + default="AdamW", help="The optimizer.", ) - parser.add_argument( - "--scheduler-name", - type=str, - default="Eden", - help="The scheduler.", - ) parser.add_argument( "--base-lr", type=float, default=0.05, help="The base learning rate." ) @@ -203,7 +205,7 @@ def get_parser(): parser.add_argument( "--decay-steps", type=int, - default=None, + default=1000000, help="""Number of steps that affects how rapidly the learning rate decreases. 
We suggest not to change this.""", ) @@ -286,6 +288,7 @@ def get_parser(): default=0.0, help="Keep only utterances with duration > this.", ) + parser.add_argument( "--filter-max-duration", type=float, @@ -293,13 +296,6 @@ def get_parser(): help="Keep only utterances with duration < this.", ) - parser.add_argument( - "--visualize", - type=str2bool, - default=False, - help="visualize model results in eval step.", - ) - parser.add_argument( "--oom-check", type=str2bool, @@ -383,6 +379,7 @@ def get_tokenizer(vocab_file_path: str): def get_model(params): vocab_char_map, vocab_size = get_tokenizer(params.tokens) + # bigvgan 100 dim features n_mel_channels = 100 n_fft = 1024 sampling_rate = 24_000 @@ -421,7 +418,6 @@ def get_model(params): def load_F5_TTS_pretrained_checkpoint( model, ckpt_path, device: str = "cpu", dtype=torch.float32 ): - # model = model.to(dtype) checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True) if "ema_model_state_dict" in checkpoint: checkpoint["model_state_dict"] = { @@ -641,14 +637,6 @@ def compute_validation_loss( params.best_valid_epoch = params.cur_epoch params.best_valid_loss = loss_value - # if params.visualize: - # output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}") - # output_dir.mkdir(parents=True, exist_ok=True) - # if isinstance(model, DDP): - # model.module.visualize(predicts, batch, output_dir=output_dir) - # else: - # model.visualize(predicts, batch, output_dir=output_dir) - return tot_loss @@ -744,11 +732,11 @@ def train_one_epoch( scaler.scale(loss).backward() if params.batch_idx_train >= params.accumulate_grad_steps: if params.batch_idx_train % params.accumulate_grad_steps == 0: - if params.optimizer_name not in ["ScaledAdam", "Eve"]: - # Unscales the gradients of optimizer's assigned params in-place - scaler.unscale_(optimizer) - # Since the gradients of optimizer's assigned params are unscaled, clips as usual: - torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + # Unscales the gradients of optimizer's assigned params in-place + scaler.unscale_(optimizer) + # Since the gradients of optimizer's assigned params are unscaled, clips as usual: + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) scaler.step(optimizer) scaler.update() @@ -757,10 +745,7 @@ def train_one_epoch( # optimizer.step() for k in range(params.accumulate_grad_steps): - if isinstance(scheduler, Eden): - scheduler.step_batch(params.batch_idx_train) - else: - scheduler.step() + scheduler.step() set_batch_count(model, params.batch_idx_train) except: # noqa @@ -940,16 +925,18 @@ def run(rank, world_size, args): logging.info(f"Device: {device}") tokenizer = get_tokenizer(params.tokens) - print("the class type of tokenizer is: ", type(tokenizer)) logging.info(params) logging.info("About to create model") model = get_model(params) + if params.pretrained_model_path: checkpoint = torch.load(params.pretrained_model_path, map_location="cpu") - if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint: - model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path) + if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint: + model = load_F5_TTS_pretrained_checkpoint( + model, params.pretrained_model_path + ) else: _ = load_checkpoint( params.pretrained_model_path, @@ -984,27 +971,24 @@ def run(rank, world_size, args): model_parameters = model.parameters() - if params.optimizer_name == "ScaledAdam": - optimizer = ScaledAdam( - model_parameters, - lr=params.base_lr, - 
clipping_scale=2.0, - ) - elif params.optimizer_name == "AdamW": - optimizer = torch.optim.AdamW( - model_parameters, - lr=params.base_lr, - betas=(0.9, 0.95), - weight_decay=1e-2, - eps=1e-8, - ) - else: - raise NotImplementedError() + optimizer = torch.optim.AdamW( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + weight_decay=1e-2, + eps=1e-8, + ) - warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps) - decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps) + warmup_scheduler = LinearLR( + optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps + ) + decay_scheduler = LinearLR( + optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps + ) scheduler = SequentialLR( - optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps] + optimizer, + schedulers=[warmup_scheduler, decay_scheduler], + milestones=[params.warmup_steps], ) optimizer.zero_grad() @@ -1062,8 +1046,6 @@ def run(rank, world_size, args): scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs + 1): - if isinstance(scheduler, Eden): - scheduler.step_epoch(epoch - 1) fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) @@ -1140,7 +1122,6 @@ def scan_pessimistic_batches_for_oom( "Sanity check -- see if any of the batches in epoch 1 would cause OOM." ) batches, crit_values = find_pessimistic_batches(train_dl.sampler) - print(23333) dtype = torch.float32 if params.dtype in ["bfloat16", "bf16"]: dtype = torch.bfloat16 diff --git a/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py index b544f1d96..80ba17318 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py @@ -24,8 +24,6 @@ from pathlib import Path from typing import Any, Dict, Optional import torch - -# from fbank import MatchaFbank, MatchaFbankConfig from lhotse import CutSet, load_manifest_lazy from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset, CutConcatenate, @@ -185,22 +183,6 @@ class TtsDataModule: raise NotImplementedError( "On-the-fly feature extraction is not implemented yet." ) - # sampling_rate = 22050 - # config = MatchaFbankConfig( - # n_fft=1024, - # n_mels=80, - # sampling_rate=sampling_rate, - # hop_length=256, - # win_length=1024, - # f_min=0, - # f_max=8000, - # ) - # train = SpeechSynthesisDataset( - # return_text=True, - # return_tokens=False, - # feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), - # return_cuts=self.args.return_cuts, - # ) if self.args.bucketing_sampler: logging.info("Using DynamicBucketingSampler.") @@ -249,22 +231,6 @@ class TtsDataModule: raise NotImplementedError( "On-the-fly feature extraction is not implemented yet." ) - # sampling_rate = 22050 - # config = MatchaFbankConfig( - # n_fft=1024, - # n_mels=80, - # sampling_rate=sampling_rate, - # hop_length=256, - # win_length=1024, - # f_min=0, - # f_max=8000, - # ) - # validate = SpeechSynthesisDataset( - # return_text=True, - # return_tokens=False, - # feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), - # return_cuts=self.args.return_cuts, - # ) else: validate = SpeechSynthesisDataset( return_text=True, @@ -296,22 +262,6 @@ class TtsDataModule: raise NotImplementedError( "On-the-fly feature extraction is not implemented yet." 
) - # sampling_rate = 22050 - # config = MatchaFbankConfig( - # n_fft=1024, - # n_mels=80, - # sampling_rate=sampling_rate, - # hop_length=256, - # win_length=1024, - # f_min=0, - # f_max=8000, - # ) - # test = SpeechSynthesisDataset( - # return_text=True, - # return_tokens=False, - # feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), - # return_cuts=self.args.return_cuts, - # ) else: test = SpeechSynthesisDataset( return_text=True, diff --git a/egs/wenetspeech4tts/TTS/infer_f5.sh b/egs/wenetspeech4tts/TTS/infer_f5.sh deleted file mode 100644 index eee412e5a..000000000 --- a/egs/wenetspeech4tts/TTS/infer_f5.sh +++ /dev/null @@ -1,15 +0,0 @@ -export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha -#bigvganinference -model_path=/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt -manifest=/home/yuekaiz/HF/valle_wenetspeech4tts_demo/wenetspeech4tts.txt -manifest=/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst -# get wenetspeech4tts -manifest_base_stem=$(basename $manifest) -mainfest_base_stem=${manifest_base_stem%.*} -output_dir=./results/f5-tts-pretrained/$mainfest_base_stem - - -pip install sherpa-onnx bigvganinference lhotse kaldialign sentencepiece -accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir || exit 1 - -bash local/compute_wer.sh $output_dir $manifest diff --git a/egs/wenetspeech4tts/TTS/local/compute_wer.sh b/egs/wenetspeech4tts/TTS/local/compute_wer.sh index 2a214cd67..283546383 100644 --- a/egs/wenetspeech4tts/TTS/local/compute_wer.sh +++ b/egs/wenetspeech4tts/TTS/local/compute_wer.sh @@ -1,6 +1,5 @@ wav_dir=$1 wav_files=$(ls $wav_dir/*.wav) -# wav_files=$(echo $wav_files | cut -d " " -f 1) # if wav_files is empty, then exit if [ -z "$wav_files" ]; then exit 1 diff --git a/egs/wenetspeech4tts/TTS/prepare.sh b/egs/wenetspeech4tts/TTS/prepare.sh index 7b800d87e..d62f74a78 100755 --- a/egs/wenetspeech4tts/TTS/prepare.sh +++ b/egs/wenetspeech4tts/TTS/prepare.sh @@ -1,5 +1,4 @@ #!/usr/bin/env bash -export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha set -eou pipefail @@ -7,8 +6,8 @@ set -eou pipefail # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python -stage=7 -stop_stage=7 +stage=1 +stop_stage=4 dl_dir=$PWD/download @@ -101,23 +100,10 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir} fi -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: build monotonic_align lib (used by matcha recipes)" - for recipe in matcha; do - if [ ! -d $recipe/monotonic_align/build ]; then - cd $recipe/monotonic_align - python3 setup.py build_ext --inplace - cd ../../ - else - log "monotonic_align lib for $recipe already built" - fi - done -fi - subset="Basic" prefix="wenetspeech4tts" -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Generate fbank (used by ./matcha)" +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate fbank (used by ./f5-tts)" mkdir -p data/fbank if [ ! 
-e data/fbank/.${prefix}.done ]; then
     ./local/compute_mel_feat.py --dataset-parts $subset --split 100
@@ -125,8 +111,8 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   fi
 fi
 
-if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
-  log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./matcha)"
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)"
   if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
     echo "Combining ${prefix} cuts"
     pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")
@@ -135,17 +121,17 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
 
   if [ ! -e data/fbank/.${prefix}_split.done ]; then
     echo "Splitting ${prefix} cuts into train, valid and test sets"
-    # lhotse subset --last 800 \
-    #     data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
-    #     data/fbank/${prefix}_cuts_validtest.jsonl.gz
-    # lhotse subset --first 400 \
-    #     data/fbank/${prefix}_cuts_validtest.jsonl.gz \
-    #     data/fbank/${prefix}_cuts_valid.jsonl.gz
-    # lhotse subset --last 400 \
-    #     data/fbank/${prefix}_cuts_validtest.jsonl.gz \
-    #     data/fbank/${prefix}_cuts_test.jsonl.gz
+    lhotse subset --last 800 \
+        data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
+        data/fbank/${prefix}_cuts_validtest.jsonl.gz
+    lhotse subset --first 400 \
+        data/fbank/${prefix}_cuts_validtest.jsonl.gz \
+        data/fbank/${prefix}_cuts_valid.jsonl.gz
+    lhotse subset --last 400 \
+        data/fbank/${prefix}_cuts_validtest.jsonl.gz \
+        data/fbank/${prefix}_cuts_test.jsonl.gz
 
-    # rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
+    rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
 
     n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
     lhotse subset --first $n \
diff --git a/egs/wenetspeech4tts/TTS/train_f5.sh b/egs/wenetspeech4tts/TTS/train_f5.sh
deleted file mode 100644
index f29563531..000000000
--- a/egs/wenetspeech4tts/TTS/train_f5.sh
+++ /dev/null
@@ -1,28 +0,0 @@
-export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
-
-install_flag=false
-if [ "$install_flag" = true ]; then
-    echo "Installing packages..."
-
-    pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
-    # pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
-    # lhotse tensorboard kaldialign
-    pip install -r requirements.txt
-    pip install phonemizer pypinyin sentencepiece kaldialign matplotlib h5py
-
-    apt-get update && apt-get -y install festival espeak-ng mbrola
-else
-    echo "Skipping installation."
-fi
-
-world_size=8
-#world_size=1
-
-exp_dir=exp/f5
-
-# pip install -r f5-tts/requirements.txt
-python3 f5-tts/train.py --max-duration 300 --filter-min-duration 0.5 --filter-max-duration 20 \
-  --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 8000 \
-  --base-lr 1e-4 --warmup-steps 5000 --average-period 200 \
-  --num-epochs 10 --start-epoch 1 --start-batch 20000 \
-  --exp-dir ${exp_dir} --world-size ${world_size}
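
For reference, a minimal standalone sketch (not taken from the patch itself) of the warmup-plus-decay schedule that train.py now builds with LinearLR and SequentialLR. The values assume the usage example in the new train.py docstring (--base-lr 7.5e-5, --warmup-steps 20000) and the new --decay-steps default of 1000000; the dummy parameter exists only so the snippet runs without a real model.

import torch
from torch.optim.lr_scheduler import LinearLR, SequentialLR

base_lr, warmup_steps, decay_steps = 7.5e-5, 20000, 1000000

# Dummy parameter so the example runs standalone.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW(
    [param], lr=base_lr, betas=(0.9, 0.95), weight_decay=1e-2, eps=1e-8
)

warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(
    optimizer, schedulers=[warmup, decay], milestones=[warmup_steps]
)

for step in range(1, warmup_steps + 3):
    optimizer.step()
    scheduler.step()
    if step in (1, warmup_steps // 2, warmup_steps, warmup_steps + 2):
        # The LR ramps linearly up to base_lr over the first 20k steps,
        # then decays linearly toward zero over the next 1M steps.
        print(step, scheduler.get_last_lr()[0])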