From 84f6e0da4ede35f5afb91d1bf152b814f2cf8270 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 9 Jan 2025 07:52:18 +0000
Subject: [PATCH] linear scheduler

---
 egs/wenetspeech4tts/TTS/f5-tts/train.py | 33 +++++++++++++++++++++----
 1 file changed, 28 insertions(+), 5 deletions(-)

diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py
index f6a0ce0e6..35e1e8e83 100755
--- a/egs/wenetspeech4tts/TTS/f5-tts/train.py
+++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py
@@ -46,6 +46,7 @@ from model.cfm import CFM
 from model.dit import DiT
 from model.utils import convert_char_to_pinyin
 from optim import Eden, ScaledAdam
+from torch.optim.lr_scheduler import LinearLR, SequentialLR
 from torch import Tensor

 # from torch.cuda.amp import GradScaler
@@ -168,12 +169,12 @@ def get_parser():
         default="f5-tts/vocab.txt",
         help="Path to the unique text tokens file",
     )
-
+    # /home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
     parser.add_argument(
         "--pretrained-model-path",
         type=str,
-        default="/home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt",
-        help="Path to the unique text tokens file",
+        default=None,
+        help="Path to the pretrained model checkpoint",
     )

     parser.add_argument(
@@ -199,6 +200,14 @@
         decreases. We suggest not to change this.""",
     )

+    parser.add_argument(
+        "--decay-steps",
+        type=int,
+        default=None,
+        help="""Number of steps over which the learning rate decays
+        linearly from its peak value to near zero after warmup.""",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -937,7 +946,16 @@ def run(rank, world_size, args):

     logging.info("About to create model")
     model = get_model(params)
-    # model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
+    if params.pretrained_model_path:
+        checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
+        if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
+            model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
+        else:
+            _ = load_checkpoint(
+                params.pretrained_model_path,
+                model=model,
+            )
+
     model = model.to(device)

     with open(f"{params.exp_dir}/model.txt", "w") as f:
@@ -990,7 +1008,12 @@ def run(rank, world_size, args):
     else:
         raise NotImplementedError()

-    scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
+    # scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
+    warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
+    decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
+    scheduler = SequentialLR(
+        optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
+    )

     optimizer.zero_grad()
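
Note on the checkpoint-loading hunk (not part of the patch): the new branch dispatches on the checkpoint's top-level keys. A minimal, self-contained sketch of that dispatch is below; the path is a hypothetical placeholder, and treating "anything else" as an icefall-style checkpoint simply mirrors the patch's else branch rather than an exhaustive format check.

import torch

ckpt_path = "exp/pretrained.pt"  # hypothetical placeholder path
checkpoint = torch.load(ckpt_path, map_location="cpu")

# F5-TTS release checkpoints wrap weights in "ema_model_state_dict" /
# "model_state_dict"; any other layout falls through to icefall's
# load_checkpoint in the patch.
if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
    print("F5-TTS-style checkpoint: use load_F5_TTS_pretrained_checkpoint")
else:
    print("icefall-style checkpoint: use load_checkpoint")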
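
The scheduler hunk replaces Eden with a linear warmup chained into a linear decay via SequentialLR. The sketch below reproduces the resulting schedule shape on its own; the base LR and step counts are made-up placeholders, not values from the recipe. SequentialLR hands control from the warmup scheduler to the decay scheduler at the warmup_steps milestone, so the peak LR is hit exactly once.

import torch
from torch.optim.lr_scheduler import LinearLR, SequentialLR

# Placeholder model, base LR, and step counts -- illustrative only.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=7.5e-5)

warmup_steps, decay_steps = 100, 1000
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
# SequentialLR switches from `warmup` to `decay` at step `warmup_steps`.
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

lrs = []
for _ in range(warmup_steps + decay_steps):
    optimizer.step()   # optimizer first, then scheduler (stepped once per batch)
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])

peak = max(lrs)
print(f"peak lr {peak:.2e} at step {lrs.index(peak) + 1}")  # ~7.5e-5 at step warmup_steps

One consequence worth flagging: --decay-steps defaults to None, and LinearLR cannot take total_iters=None, so the flag is effectively required once this scheduler is in place.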