mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00

code clean

This commit is contained in:
parent 29de94ee2a
commit 455366418c
@@ -11,7 +11,7 @@ repos:
     rev: 5.0.4
     hooks:
       - id: flake8
-        args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503, F722, F821"]
+        args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
         #exclude:

 # What are we ignoring here?
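The dropped codes are F722 (syntax error in forward annotation) and F821 (undefined name), which flake8 typically raises on string-style type annotations. A hypothetical illustration, not from this repo, of what each code flags:

# Hypothetical jaxtyping-style annotations. flake8 reads the inner string
# "batch seq" as a forward annotation and, since it is not valid Python
# syntax, reports F722; names like Float or Tensor that are never imported
# are reported as F821 (undefined name).
def embed(x: Float[Tensor, "batch seq"]) -> Float[Tensor, "batch dim"]:
    ...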
@@ -1,3 +1,19 @@
 #!/usr/bin/env python3
 # Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py
+"""
+Usage:
+# docker: ghcr.io/swivid/f5-tts:main
+# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
+# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx
+# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
+manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst
+python3 f5-tts/generate_averaged_model.py \
+    --epoch 56 \
+    --avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
+    --exp-dir exp/f5_small
+accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
+bash local/compute_wer.sh $output_dir $manifest
+"""
 import argparse
 import logging
 import math
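The --epoch 56 --avg 14 invocation above follows the icefall convention of averaging the last 14 epoch checkpoints (epochs 43 through 56) into a single model. A minimal sketch of uniform state-dict averaging, assuming epoch-N.pt files whose weights sit under a "model" key:

import torch

def average_checkpoints(paths):
    # Uniformly average model weights across several checkpoints.
    avg = None
    for path in paths:
        state = torch.load(path, map_location="cpu")["model"]  # key assumed
        if avg is None:
            avg = {k: v.float().clone() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(paths) for k, v in avg.items()}

# --epoch 56 --avg 14 would cover epoch-43.pt .. epoch-56.pt (naming assumed):
averaged = average_checkpoints([f"exp/f5_small/epoch-{i}.pt" for i in range(43, 57)])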
@@ -62,7 +78,7 @@ def get_parser():
     parser.add_argument(
         "--manifest-file",
         type=str,
-        default="/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst",
+        default="/path/seed_tts_eval/seedtts_testset/zh/meta.lst",
         help="The manifest file in seed_tts_eval format",
     )

@@ -180,7 +196,6 @@ def get_inference_prompt(
         batch_accum[bucket_i] += total_mel_len

         if batch_accum[bucket_i] >= infer_batch_size:
-            # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
             prompts_all.append(
                 (
                     utts[bucket_i],

@@ -282,7 +297,7 @@ def main():

     model = get_model(args).eval().to(device)
     checkpoint = torch.load(args.model_path, map_location="cpu")
-    if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint:
+    if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
         model = load_F5_TTS_pretrained_checkpoint(model, args.model_path)
     else:
         _ = load_checkpoint(

@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/optim.py
@@ -20,8 +20,17 @@
 # limitations under the License.
 """
 Usage:
+# docker: ghcr.io/swivid/f5-tts:main
+# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
+# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
+# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
 world_size=8
-exp_dir=exp/ft-tts
+exp_dir=exp/f5-tts-small
 python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
     --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
     --base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
+    --num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
     --exp-dir ${exp_dir} --world-size ${world_size}
 """

 import argparse

@@ -45,13 +54,10 @@ from lhotse.utils import fix_random_seed
 from model.cfm import CFM
 from model.dit import DiT
 from model.utils import convert_char_to_pinyin
-from optim import Eden, ScaledAdam
-from torch.optim.lr_scheduler import LinearLR, SequentialLR
 from torch import Tensor
-
-# from torch.cuda.amp import GradScaler
 from torch.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim.lr_scheduler import LinearLR, SequentialLR
 from torch.utils.tensorboard import SummaryWriter
 from tts_datamodule import TtsDataModule
 from utils import MetricsTracker
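The import switch from the deprecated torch.cuda.amp.GradScaler to torch.amp.GradScaler is mostly mechanical; the newer class takes the device type as its first argument. A small self-contained sketch of the pattern (toy model and shapes):

import torch
from torch.amp import GradScaler, autocast

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler("cuda")  # device string replaces the old torch.cuda.amp form

x = torch.randn(4, 8, device="cuda")
with autocast("cuda", dtype=torch.float16):
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()  # scale the loss so fp16 grads do not underflow
scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
scaler.update()

With bfloat16, the dtype this recipe trains in, the scaling is largely a formality, but the same API covers float16 as well.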
@@ -87,12 +93,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default=1024,
         help="Embedding dimension in the decoder model.",
     )
+
     parser.add_argument(
         "--nhead",
         type=int,
         default=16,
         help="Number of attention heads in the Decoder layers.",
     )
+
     parser.add_argument(
         "--num-decoder-layers",
         type=int,

@@ -156,7 +164,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=Path,
-        default="exp/valle_dev",
+        default="exp/f5",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved

@@ -169,7 +177,7 @@ def get_parser():
         default="f5-tts/vocab.txt",
         help="Path to the unique text tokens file",
     )
-    # /home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
+
     parser.add_argument(
         "--pretrained-model-path",
         type=str,

@@ -180,15 +188,9 @@ def get_parser():
     parser.add_argument(
         "--optimizer-name",
         type=str,
-        default="ScaledAdam",
+        default="AdamW",
         help="The optimizer.",
     )
-    parser.add_argument(
-        "--scheduler-name",
-        type=str,
-        default="Eden",
-        help="The scheduler.",
-    )
     parser.add_argument(
         "--base-lr", type=float, default=0.05, help="The base learning rate."
     )

@@ -203,7 +205,7 @@ def get_parser():
     parser.add_argument(
         "--decay-steps",
         type=int,
-        default=None,
+        default=1000000,
         help="""Number of steps that affects how rapidly the learning rate
         decreases. We suggest not to change this.""",
     )

@@ -286,6 +288,7 @@ def get_parser():
         default=0.0,
         help="Keep only utterances with duration > this.",
     )
+
     parser.add_argument(
         "--filter-max-duration",
         type=float,

@@ -293,13 +296,6 @@ def get_parser():
         help="Keep only utterances with duration < this.",
     )

-    parser.add_argument(
-        "--visualize",
-        type=str2bool,
-        default=False,
-        help="visualize model results in eval step.",
-    )
-
     parser.add_argument(
         "--oom-check",
         type=str2bool,

@@ -383,6 +379,7 @@ def get_tokenizer(vocab_file_path: str):

 def get_model(params):
     vocab_char_map, vocab_size = get_tokenizer(params.tokens)
+    # bigvgan 100 dim features
     n_mel_channels = 100
     n_fft = 1024
     sampling_rate = 24_000

@@ -421,7 +418,6 @@ def get_model(params):
 def load_F5_TTS_pretrained_checkpoint(
     model, ckpt_path, device: str = "cpu", dtype=torch.float32
 ):
-    # model = model.to(dtype)
     checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
     if "ema_model_state_dict" in checkpoint:
         checkpoint["model_state_dict"] = {
@@ -641,14 +637,6 @@ def compute_validation_loss(
         params.best_valid_epoch = params.cur_epoch
         params.best_valid_loss = loss_value

-    # if params.visualize:
-    #     output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
-    #     output_dir.mkdir(parents=True, exist_ok=True)
-    #     if isinstance(model, DDP):
-    #         model.module.visualize(predicts, batch, output_dir=output_dir)
-    #     else:
-    #         model.visualize(predicts, batch, output_dir=output_dir)
-
     return tot_loss


@@ -744,11 +732,11 @@ def train_one_epoch(
             scaler.scale(loss).backward()
-            if params.batch_idx_train >= params.accumulate_grad_steps:
-                if params.optimizer_name not in ["ScaledAdam", "Eve"]:
-                    # Unscales the gradients of optimizer's assigned params in-place
-                    scaler.unscale_(optimizer)
-                    # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
-                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-
+            if params.batch_idx_train % params.accumulate_grad_steps == 0:
+                # Unscales the gradients of optimizer's assigned params in-place
+                scaler.unscale_(optimizer)
+                # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
                 scaler.step(optimizer)
                 scaler.update()

@@ -757,10 +745,7 @@ def train_one_epoch(
                 # optimizer.step()

-                for k in range(params.accumulate_grad_steps):
-                    if isinstance(scheduler, Eden):
-                        scheduler.step_batch(params.batch_idx_train)
-                    else:
-                        scheduler.step()
+                scheduler.step()

                 set_batch_count(model, params.batch_idx_train)
         except:  # noqa
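The switch from >= to % is the substantive fix here: with >=, the condition held on every batch once the counter passed the threshold, so gradients were never actually accumulated; with %, the optimizer steps once per window of accumulate_grad_steps micro-batches. A standalone sketch of the corrected pattern (toy names; dividing the loss by the window size is common practice, though not shown in the diff):

accumulate_grad_steps = 4
for batch_idx, batch in enumerate(train_batches, start=1):
    loss = compute_loss(model, batch) / accumulate_grad_steps  # hypothetical helper
    loss.backward()                      # gradients accumulate across iterations
    if batch_idx % accumulate_grad_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()                 # one update per window of 4 micro-batches
        optimizer.zero_grad()
        scheduler.step()                 # step the LR schedule once per update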
@@ -940,16 +925,18 @@ def run(rank, world_size, args):

     logging.info(f"Device: {device}")
     tokenizer = get_tokenizer(params.tokens)
-    print("the class type of tokenizer is: ", type(tokenizer))
     logging.info(params)

     logging.info("About to create model")
+
     model = get_model(params)
+
     if params.pretrained_model_path:
         checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
-        if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint:
-            model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
+        if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
+            model = load_F5_TTS_pretrained_checkpoint(
+                model, params.pretrained_model_path
+            )
         else:
             _ = load_checkpoint(
                 params.pretrained_model_path,

@@ -984,27 +971,24 @@ def run(rank, world_size, args):

     model_parameters = model.parameters()

-    if params.optimizer_name == "ScaledAdam":
-        optimizer = ScaledAdam(
-            model_parameters,
-            lr=params.base_lr,
-            clipping_scale=2.0,
-        )
-    elif params.optimizer_name == "AdamW":
-        optimizer = torch.optim.AdamW(
-            model_parameters,
-            lr=params.base_lr,
-            betas=(0.9, 0.95),
-            weight_decay=1e-2,
-            eps=1e-8,
-        )
-    else:
-        raise NotImplementedError()
+    optimizer = torch.optim.AdamW(
+        model_parameters,
+        lr=params.base_lr,
+        betas=(0.9, 0.95),
+        weight_decay=1e-2,
+        eps=1e-8,
+    )

-    warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
-    decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
+    warmup_scheduler = LinearLR(
+        optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps
+    )
+    decay_scheduler = LinearLR(
+        optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps
+    )
     scheduler = SequentialLR(
-        optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
+        optimizer,
+        schedulers=[warmup_scheduler, decay_scheduler],
+        milestones=[params.warmup_steps],
     )

     optimizer.zero_grad()
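SequentialLR here chains two linear ramps: the first scheduler takes the learning rate from near zero up to base-lr over warmup-steps, then control passes to the second at the milestone, which decays it back toward zero over decay-steps. A runnable toy version with shortened horizons:

import torch
from torch.optim.lr_scheduler import LinearLR, SequentialLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=7.5e-5)

warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=100)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=1000)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[100])

for step in range(1100):
    optimizer.step()
    scheduler.step()  # LR rises to 7.5e-5 by step 100, then decays linearly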
@@ -1062,8 +1046,6 @@ def run(rank, world_size, args):
         scaler.load_state_dict(checkpoints["grad_scaler"])

     for epoch in range(params.start_epoch, params.num_epochs + 1):
-        if isinstance(scheduler, Eden):
-            scheduler.step_epoch(epoch - 1)
-
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)

@@ -1140,7 +1122,6 @@ def scan_pessimistic_batches_for_oom(
         "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
     )
     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
-    print(23333)
     dtype = torch.float32
     if params.dtype in ["bfloat16", "bf16"]:
         dtype = torch.bfloat16

@@ -24,8 +24,6 @@ from pathlib import Path
 from typing import Any, Dict, Optional

 import torch
-
-# from fbank import MatchaFbank, MatchaFbankConfig
 from lhotse import CutSet, load_manifest_lazy
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset,
     CutConcatenate,

@@ -185,22 +183,6 @@ class TtsDataModule:
             raise NotImplementedError(
                 "On-the-fly feature extraction is not implemented yet."
             )
-            # sampling_rate = 22050
-            # config = MatchaFbankConfig(
-            #     n_fft=1024,
-            #     n_mels=80,
-            #     sampling_rate=sampling_rate,
-            #     hop_length=256,
-            #     win_length=1024,
-            #     f_min=0,
-            #     f_max=8000,
-            # )
-            # train = SpeechSynthesisDataset(
-            #     return_text=True,
-            #     return_tokens=False,
-            #     feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
-            #     return_cuts=self.args.return_cuts,
-            # )

         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
@@ -249,22 +231,6 @@ class TtsDataModule:
             raise NotImplementedError(
                 "On-the-fly feature extraction is not implemented yet."
             )
-            # sampling_rate = 22050
-            # config = MatchaFbankConfig(
-            #     n_fft=1024,
-            #     n_mels=80,
-            #     sampling_rate=sampling_rate,
-            #     hop_length=256,
-            #     win_length=1024,
-            #     f_min=0,
-            #     f_max=8000,
-            # )
-            # validate = SpeechSynthesisDataset(
-            #     return_text=True,
-            #     return_tokens=False,
-            #     feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
-            #     return_cuts=self.args.return_cuts,
-            # )
         else:
             validate = SpeechSynthesisDataset(
                 return_text=True,

@@ -296,22 +262,6 @@ class TtsDataModule:
             raise NotImplementedError(
                 "On-the-fly feature extraction is not implemented yet."
             )
-            # sampling_rate = 22050
-            # config = MatchaFbankConfig(
-            #     n_fft=1024,
-            #     n_mels=80,
-            #     sampling_rate=sampling_rate,
-            #     hop_length=256,
-            #     win_length=1024,
-            #     f_min=0,
-            #     f_max=8000,
-            # )
-            # test = SpeechSynthesisDataset(
-            #     return_text=True,
-            #     return_tokens=False,
-            #     feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
-            #     return_cuts=self.args.return_cuts,
-            # )
         else:
             test = SpeechSynthesisDataset(
                 return_text=True,

@@ -1,15 +0,0 @@
-export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
-#bigvganinference
-model_path=/home/yuekaiz/HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
-manifest=/home/yuekaiz/HF/valle_wenetspeech4tts_demo/wenetspeech4tts.txt
-manifest=/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst
-# get wenetspeech4tts
-manifest_base_stem=$(basename $manifest)
-mainfest_base_stem=${manifest_base_stem%.*}
-output_dir=./results/f5-tts-pretrained/$mainfest_base_stem
-
-
-pip install sherpa-onnx bigvganinference lhotse kaldialign sentencepiece
-accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir || exit 1
-
-bash local/compute_wer.sh $output_dir $manifest

@@ -1,6 +1,5 @@
 wav_dir=$1
 wav_files=$(ls $wav_dir/*.wav)
-# wav_files=$(echo $wav_files | cut -d " " -f 1)
 # if wav_files is empty, then exit
 if [ -z "$wav_files" ]; then
     exit 1

@@ -1,5 +1,4 @@
 #!/usr/bin/env bash
-export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
 set -eou pipefail


@@ -7,8 +6,8 @@ set -eou pipefail
 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

-stage=7
-stop_stage=7
+stage=1
+stop_stage=4

 dl_dir=$PWD/download


@@ -101,23 +100,10 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
     python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir}
 fi

-if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-    log "Stage 5: build monotonic_align lib (used by matcha recipes)"
-    for recipe in matcha; do
-        if [ ! -d $recipe/monotonic_align/build ]; then
-            cd $recipe/monotonic_align
-            python3 setup.py build_ext --inplace
-            cd ../../
-        else
-            log "monotonic_align lib for $recipe already built"
-        fi
-    done
-fi
-
 subset="Basic"
 prefix="wenetspeech4tts"
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-    log "Stage 6: Generate fbank (used by ./matcha)"
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+    log "Stage 5: Generate fbank (used by ./f5-tts)"
     mkdir -p data/fbank
     if [ ! -e data/fbank/.${prefix}.done ]; then
         ./local/compute_mel_feat.py --dataset-parts $subset --split 100

@@ -125,8 +111,8 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     fi
 fi

-if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
-    log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./matcha)"
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+    log "Stage 7: Split the ${prefix} cuts into train, valid and test sets (used by ./f5-tts)"
     if [ ! -f data/fbank/${prefix}_cuts_${subset}.jsonl.gz ]; then
         echo "Combining ${prefix} cuts"
         pieces=$(find data/fbank/ -name "${prefix}_cuts_${subset}.*.jsonl.gz")

@@ -135,17 +121,17 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
     if [ ! -e data/fbank/.${prefix}_split.done ]; then
         echo "Splitting ${prefix} cuts into train, valid and test sets"

-        # lhotse subset --last 800 \
-        #     data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
-        #     data/fbank/${prefix}_cuts_validtest.jsonl.gz
-        # lhotse subset --first 400 \
-        #     data/fbank/${prefix}_cuts_validtest.jsonl.gz \
-        #     data/fbank/${prefix}_cuts_valid.jsonl.gz
-        # lhotse subset --last 400 \
-        #     data/fbank/${prefix}_cuts_validtest.jsonl.gz \
-        #     data/fbank/${prefix}_cuts_test.jsonl.gz
+        lhotse subset --last 800 \
+            data/fbank/${prefix}_cuts_${subset}.jsonl.gz \
+            data/fbank/${prefix}_cuts_validtest.jsonl.gz
+        lhotse subset --first 400 \
+            data/fbank/${prefix}_cuts_validtest.jsonl.gz \
+            data/fbank/${prefix}_cuts_valid.jsonl.gz
+        lhotse subset --last 400 \
+            data/fbank/${prefix}_cuts_validtest.jsonl.gz \
+            data/fbank/${prefix}_cuts_test.jsonl.gz

-        # rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
+        rm data/fbank/${prefix}_cuts_validtest.jsonl.gz

         n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
         lhotse subset --first $n \
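For reference, the same split can be expressed with lhotse's Python API; a sketch using the file names from above (the 800 held-out cuts become 400 valid plus 400 test, the rest train):

from lhotse import CutSet

cuts = CutSet.from_file("data/fbank/wenetspeech4tts_cuts_Basic.jsonl.gz")
validtest = cuts.subset(last=800)  # hold out the final 800 cuts
validtest.subset(first=400).to_file("data/fbank/wenetspeech4tts_cuts_valid.jsonl.gz")
validtest.subset(last=400).to_file("data/fbank/wenetspeech4tts_cuts_test.jsonl.gz")
cuts.subset(first=len(cuts) - 800).to_file(
    "data/fbank/wenetspeech4tts_cuts_train.jsonl.gz"
)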
@@ -1,28 +0,0 @@
-export PYTHONPATH=$PYTHONPATH:/home/yuekaiz/icefall_matcha
-
-install_flag=false
-if [ "$install_flag" = true ]; then
-    echo "Installing packages..."
-
-    pip install k2==1.24.3.dev20230524+cuda11.8.torch2.0.1 -f https://k2-fsa.github.io/k2/cuda.html
-    # pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
-    # lhotse tensorboard kaldialign
-    pip install -r requirements.txt
-    pip install phonemizer pypinyin sentencepiece kaldialign matplotlib h5py
-
-    apt-get update && apt-get -y install festival espeak-ng mbrola
-else
-    echo "Skipping installation."
-fi
-
-world_size=8
-#world_size=1
-
-exp_dir=exp/f5
-
-# pip install -r f5-tts/requirements.txt
-python3 f5-tts/train.py --max-duration 300 --filter-min-duration 0.5 --filter-max-duration 20 \
-    --num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 8000 \
-    --base-lr 1e-4 --warmup-steps 5000 --average-period 200 \
-    --num-epochs 10 --start-epoch 1 --start-batch 20000 \
-    --exp-dir ${exp_dir} --world-size ${world_size}