code clean

root 2025-01-20 09:07:19 +00:00
parent 29de94ee2a
commit 455366418c
9 changed files with 81 additions and 194 deletions

View File

@@ -11,7 +11,7 @@ repos:
rev: 5.0.4
hooks:
- id: flake8
-args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503, F722, F821"]
+args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
#exclude:
# What are we ignoring here?

View File

@@ -1,3 +1,19 @@
+#!/usr/bin/env python3
+# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py
+"""
+Usage:
+# docker: ghcr.io/swivid/f5-tts:main
+# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
+# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx
+# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
+manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst
+python3 f5-tts/generate_averaged_model.py \
+--epoch 56 \
+--avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
+--exp-dir exp/f5_small
+accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
+bash local/compute_wer.sh $output_dir $manifest
+"""
import argparse
import logging
import math
@@ -62,7 +78,7 @@ def get_parser():
parser.add_argument(
"--manifest-file",
type=str,
-default="/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst",
+default="/path/seed_tts_eval/seedtts_testset/zh/meta.lst",
help="The manifest file in seed_tts_eval format",
)
@@ -180,7 +196,6 @@ def get_inference_prompt(
batch_accum[bucket_i] += total_mel_len
if batch_accum[bucket_i] >= infer_batch_size:
-# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
prompts_all.append(
(
utts[bucket_i],
@@ -282,7 +297,7 @@ def main():
model = get_model(args).eval().to(device)
checkpoint = torch.load(args.model_path, map_location="cpu")
-if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint:
+if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
model = load_F5_TTS_pretrained_checkpoint(model, args.model_path)
else:
_ = load_checkpoint(

View File

@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/optim.py

View File

@@ -20,8 +20,17 @@
# limitations under the License.
"""
Usage:
+# docker: ghcr.io/swivid/f5-tts:main
+# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
+# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
+# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
world_size=8
-exp_dir=exp/ft-tts
+exp_dir=exp/f5-tts-small
+python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
+--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
+--base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
+--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
+--exp-dir ${exp_dir} --world-size ${world_size}
"""
import argparse
@@ -45,13 +54,10 @@ from lhotse.utils import fix_random_seed
from model.cfm import CFM
from model.dit import DiT
from model.utils import convert_char_to_pinyin
-from optim import Eden, ScaledAdam
-from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch import Tensor
-# from torch.cuda.amp import GradScaler
from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import TtsDataModule
from utils import MetricsTracker
@@ -87,12 +93,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default=1024,
help="Embedding dimension in the decoder model.",
)
parser.add_argument(
"--nhead",
type=int,
default=16,
help="Number of attention heads in the Decoder layers.",
)
parser.add_argument(
"--num-decoder-layers",
type=int,
@@ -156,7 +164,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=Path,
-default="exp/valle_dev",
+default="exp/f5",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@@ -169,7 +177,7 @@ def get_parser():
default="f5-tts/vocab.txt",
help="Path to the unique text tokens file",
)
-# /home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
parser.add_argument(
"--pretrained-model-path",
type=str,
@@ -180,15 +188,9 @@ def get_parser():
parser.add_argument(
"--optimizer-name",
type=str,
-default="ScaledAdam",
+default="AdamW",
help="The optimizer.",
)
-parser.add_argument(
-"--scheduler-name",
-type=str,
-default="Eden",
-help="The scheduler.",
-)
parser.add_argument(
"--base-lr", type=float, default=0.05, help="The base learning rate."
)
@@ -203,7 +205,7 @@ def get_parser():
parser.add_argument(
"--decay-steps",
type=int,
-default=None,
+default=1000000,
help="""Number of steps that affects how rapidly the learning rate
decreases. We suggest not to change this.""",
)
@@ -286,6 +288,7 @@ def get_parser():
default=0.0,
help="Keep only utterances with duration > this.",
)
parser.add_argument(
"--filter-max-duration",
type=float,
@@ -293,13 +296,6 @@ def get_parser():
help="Keep only utterances with duration < this.",
)
-parser.add_argument(
-"--visualize",
-type=str2bool,
-default=False,
-help="visualize model results in eval step.",
-)
parser.add_argument(
"--oom-check",
type=str2bool,
@@ -383,6 +379,7 @@ def get_tokenizer(vocab_file_path: str):
def get_model(params):
vocab_char_map, vocab_size = get_tokenizer(params.tokens)
+# bigvgan 100 dim features
n_mel_channels = 100
n_fft = 1024
sampling_rate = 24_000
@@ -421,7 +418,6 @@ def get_model(params):
def load_F5_TTS_pretrained_checkpoint(
model, ckpt_path, device: str = "cpu", dtype=torch.float32
):
-# model = model.to(dtype)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
if "ema_model_state_dict" in checkpoint:
checkpoint["model_state_dict"] = {
@@ -641,14 +637,6 @@ def compute_validation_loss(
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
-# if params.visualize:
-# output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
-# output_dir.mkdir(parents=True, exist_ok=True)
-# if isinstance(model, DDP):
-# model.module.visualize(predicts, batch, output_dir=output_dir)
-# else:
-# model.visualize(predicts, batch, output_dir=output_dir)
return tot_loss
@@ -744,7 +732,7 @@ def train_one_epoch(
scaler.scale(loss).backward()
if params.batch_idx_train >= params.accumulate_grad_steps:
if params.batch_idx_train % params.accumulate_grad_steps == 0:
-if params.optimizer_name not in ["ScaledAdam", "Eve"]:
# Unscales the gradients of optimizer's assigned params in-place
scaler.unscale_(optimizer)
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
@@ -757,9 +745,6 @@ def train_one_epoch(
# optimizer.step()
for k in range(params.accumulate_grad_steps):
-if isinstance(scheduler, Eden):
-scheduler.step_batch(params.batch_idx_train)
-else:
scheduler.step()
set_batch_count(model, params.batch_idx_train)
@@ -940,16 +925,18 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
tokenizer = get_tokenizer(params.tokens)
-print("the class type of tokenizer is: ", type(tokenizer))
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
-if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint:
-model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
+if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
+model = load_F5_TTS_pretrained_checkpoint(
+model, params.pretrained_model_path
+)
else:
_ = load_checkpoint(
params.pretrained_model_path,
@@ -984,13 +971,6 @@ def run(rank, world_size, args):
model_parameters = model.parameters()
-if params.optimizer_name == "ScaledAdam":
-optimizer = ScaledAdam(
-model_parameters,
-lr=params.base_lr,
-clipping_scale=2.0,
-)
-elif params.optimizer_name == "AdamW":
optimizer = torch.optim.AdamW(
model_parameters,
lr=params.base_lr,
@@ -998,13 +978,17 @@ def run(rank, world_size, args):
weight_decay=1e-2,
eps=1e-8,
)
-else:
-raise NotImplementedError()
-warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
-decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
+warmup_scheduler = LinearLR(
+optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps
+)
+decay_scheduler = LinearLR(
+optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps
+)
scheduler = SequentialLR(
-optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
+optimizer,
+schedulers=[warmup_scheduler, decay_scheduler],
+milestones=[params.warmup_steps],
)
optimizer.zero_grad()
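
For context: the schedule above chains a linear warmup into a linear decay via SequentialLR. A standalone sketch with a toy parameter and made-up step counts (the real recipe uses its own --base-lr, --warmup-steps, and --decay-steps values):

import torch
from torch.optim.lr_scheduler import LinearLR, SequentialLR

param = torch.nn.Parameter(torch.zeros(1))       # stand-in for the model parameters
optimizer = torch.optim.AdamW([param], lr=1e-4)  # AdamW, as in the script above

warmup_steps, decay_steps = 10, 100              # illustrative values only
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

for step in range(30):
    optimizer.step()       # normally preceded by a backward pass
    scheduler.step()       # one scheduler step per optimizer step
    if step in (0, 9, 10, 29):
        print(step, scheduler.get_last_lr())  # LR ramps up, then decays linearly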
@@ -1062,8 +1046,6 @@ def run(rank, world_size, args):
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
-if isinstance(scheduler, Eden):
-scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
@@ -1140,7 +1122,6 @@ def scan_pessimistic_batches_for_oom(
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
-print(23333)
dtype = torch.float32
if params.dtype in ["bfloat16", "bf16"]:
dtype = torch.bfloat16

View File

@@ -24,8 +24,6 @@ from pathlib import Path
from typing import Any, Dict, Optional
import torch
-# from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset,
CutConcatenate,
@@ -185,22 +183,6 @@ class TtsDataModule:
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
-# sampling_rate = 22050
-# config = MatchaFbankConfig(
-# n_fft=1024,
-# n_mels=80,
-# sampling_rate=sampling_rate,
-# hop_length=256,
-# win_length=1024,
-# f_min=0,
-# f_max=8000,
-# )
-# train = SpeechSynthesisDataset(
-# return_text=True,
-# return_tokens=False,
-# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
-# return_cuts=self.args.return_cuts,
-# )
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
@@ -249,22 +231,6 @@ class TtsDataModule:
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
-# sampling_rate = 22050
-# config = MatchaFbankConfig(
-# n_fft=1024,
-# n_mels=80,
-# sampling_rate=sampling_rate,
-# hop_length=256,
-# win_length=1024,
-# f_min=0,
-# f_max=8000,
-# )
-# validate = SpeechSynthesisDataset(
-# return_text=True,
-# return_tokens=False,
-# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
-# return_cuts=self.args.return_cuts,
-# )
else:
validate = SpeechSynthesisDataset(
return_text=True,
@@ -296,22 +262,6 @@ class TtsDataModule:
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
-# sampling_rate = 22050
-# config = MatchaFbankConfig(
-# n_fft=1024,
-# n_mels=80,
-# sampling_rate=sampling_rate,
-# hop_length=256,
-# win_length=1024,
-# f_min=0,
-# f_max=8000,
-# )
-# test = SpeechSynthesisDataset(
-# return_text=True,
-# return_tokens=False,
-# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
-# return_cuts=self.args.return_cuts,
-# )
else:
test = SpeechSynthesisDataset(
return_text=True,

View File

@@ -1,15 +0,0 @@
-export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
-#bigvganinference
-model_path=/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
-manifest=/home/yuekaiz/HF/valle_wenetspeech4tts_demo/wenetspeech4tts.txt
-manifest=/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst
-# get wenetspeech4tts
-manifest_base_stem=$(basename $manifest)
-mainfest_base_stem=${manifest_base_stem%.*}
-output_dir=./results/f5-tts-pretrained/$mainfest_base_stem
-pip install sherpa-onnx bigvganinference lhotse kaldialign sentencepiece
-accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir || exit 1
-bash local/compute_wer.sh $output_dir $manifest

View File

@@ -1,6 +1,5 @@
wav_dir=$1
wav_files=$(ls $wav_dir/*.wav)
-# wav_files=$(echo $wav_files | cut -d " " -f 1)
# if wav_files is empty, then exit
if [ -z "$wav_files" ]; then
exit 1

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env bash
-export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
set -eou pipefail
@@ -7,8 +6,8 @@ set -eou pipefail
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
-stage=7
-stop_stage=7
+stage=1
+stop_stage=4
dl_dir=$PWD/download
@@ -101,23 +100,10 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
fi
-if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-log "Stage 5: build monotonic_align lib (used by matcha recipes)"
-for recipe in matcha; do
-if [ ! -d $recipe/monotonic_align/build ]; then
-cd $recipe/monotonic_align
-python3 setup.py build_ext --inplace
-cd ../../
-else
-log "monotonic_align lib for $recipe already built"
-fi
-done
-fi
subset="Basic"
prefix="wenetspeech4tts"
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-log "Stage 6: Generate fbank (used by ./matcha)"
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+log "Stage 5: Generate fbank (used by ./f5-tts)"
mkdir -p data/fbank
if [ ! -e data/fbank/.${prefix}.done ]; then
./local/compute_mel_feat.py --dataset-parts $subset --split 100
@@ -125,8 +111,8 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
fi
fi
-if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
-log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./matcha)"
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)"
if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
echo "Combining ${prefix} cuts"
pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")
@@ -135,17 +121,17 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
if [ ! -e data/fbank/.${prefix}_split.done ]; then
echo "Splitting ${prefix} cuts into train, valid and test sets"
-# lhotse subset --last 800 \
-# data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
-# data/fbank/${prefix}_cuts_validtest.jsonl.gz
-# lhotse subset --first 400 \
-# data/fbank/${prefix}_cuts_validtest.jsonl.gz \
-# data/fbank/${prefix}_cuts_valid.jsonl.gz
-# lhotse subset --last 400 \
-# data/fbank/${prefix}_cuts_validtest.jsonl.gz \
-# data/fbank/${prefix}_cuts_test.jsonl.gz
-# rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
+lhotse subset --last 800 \
+data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
+data/fbank/${prefix}_cuts_validtest.jsonl.gz
+lhotse subset --first 400 \
+data/fbank/${prefix}_cuts_validtest.jsonl.gz \
+data/fbank/${prefix}_cuts_valid.jsonl.gz
+lhotse subset --last 400 \
+data/fbank/${prefix}_cuts_validtest.jsonl.gz \
+data/fbank/${prefix}_cuts_test.jsonl.gz
+rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
lhotse subset --first $n \

View File

@@ -1,28 +0,0 @@
-export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
-install_flag=false
-if [ "$install_flag" = true ]; then
-echo "Installing packages..."
-pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
-# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
-# lhotse tensorboard kaldialign
-pip install -r requirements.txt
-pip install phonemizer pypinyin sentencepiece kaldialign matplotlib h5py
-apt-get update && apt-get -y install festival espeak-ng mbrola
-else
-echo "Skipping installation."
-fi
-world_size=8
-#world_size=1
-exp_dir=exp/f5
-# pip install -r f5-tts/requirements.txt
-python3 f5-tts/train.py --max-duration 300 --filter-min-duration 0.5 --filter-max-duration 20 \
---num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 8000 \
---base-lr 1e-4 --warmup-steps 5000 --average-period 200 \
---num-epochs 10 --start-epoch 1 --start-batch 20000 \
---exp-dir ${exp_dir} --world-size ${world_size}