linear scheduler

Author: root, 2025-01-09 07:52:18 +00:00
parent f3fca0c81b
commit 84f6e0da4e

@@ -46,6 +46,7 @@ from model.cfm import CFM
 from model.dit import DiT
 from model.utils import convert_char_to_pinyin
 from optim import Eden, ScaledAdam
+from torch.optim.lr_scheduler import LinearLR, SequentialLR
 from torch import Tensor
 # from torch.cuda.amp import GradScaler
@@ -168,12 +169,12 @@ def get_parser():
         default="f5-tts/vocab.txt",
         help="Path to the unique text tokens file",
     )
     # /home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
     parser.add_argument(
         "--pretrained-model-path",
         type=str,
-        default="/home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt",
-        help="Path to the unique text tokens file",
+        default=None,
+        help="Path to file",
     )
     parser.add_argument(
@@ -199,6 +200,14 @@ def get_parser():
         decreases. We suggest not to change this.""",
     )
     parser.add_argument(
+        "--decay-steps",
+        type=int,
+        default=None,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+    parser.add_argument(
         "--seed",
         type=int,
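
With both new options defaulting to None, pretrained initialization and the decay length become opt-in. A minimal parsing sketch (option names mirror the diff; the prog name and example value are illustrative):

import argparse

parser = argparse.ArgumentParser(prog="train")
parser.add_argument("--pretrained-model-path", type=str, default=None)
parser.add_argument("--decay-steps", type=int, default=None)

args = parser.parse_args([])                # no flags given
assert args.pretrained_model_path is None   # -> checkpoint loading below is skipped

args = parser.parse_args(["--decay-steps", "100000"])
assert args.decay_steps == 100000           # feeds total_iters of the decay scheduler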
@@ -937,7 +946,16 @@ def run(rank, world_size, args):
     logging.info("About to create model")
     model = get_model(params)
     # model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
+    if params.pretrained_model_path:
+        checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
+        if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
+            model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
+        else:
+            _ = load_checkpoint(
+                params.pretrained_model_path,
+                model=model,
+            )
     model = model.to(device)
     with open(f"{params.exp_dir}/model.txt", "w") as f:
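
A quick way to check which branch a given file will take is to inspect its top-level keys. The file name below is a placeholder, and the key layout is the one the diff assumes: F5-TTS release checkpoints wrap weights under "ema_model_state_dict" or "model_state_dict", while anything else falls through to the generic load_checkpoint:

import torch

ckpt = torch.load("model_1250000.pt", map_location="cpu")  # placeholder path
print(sorted(ckpt.keys()))
if "ema_model_state_dict" in ckpt or "model_state_dict" in ckpt:
    print("-> load_F5_TTS_pretrained_checkpoint")
else:
    print("-> load_checkpoint")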
@@ -990,7 +1008,12 @@ def run(rank, world_size, args):
     else:
         raise NotImplementedError()
-    scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
+    # scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
+    warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
+    decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
+    scheduler = SequentialLR(
+        optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
+    )
     optimizer.zero_grad()
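
The resulting schedule ramps the learning rate linearly from near zero up to the base LR over warmup_steps, then decays it linearly back toward zero over decay_steps. A self-contained sketch of the same construction; the toy optimizer, base LR, and step counts stand in for the script's params.warmup_steps / params.decay_steps:

import torch
from torch.optim.lr_scheduler import LinearLR, SequentialLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)

warmup_steps, decay_steps = 10, 40
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

for step in range(warmup_steps + decay_steps):
    optimizer.step()   # LR climbs to 1e-4 over the first 10 steps,
    scheduler.step()   # then falls back toward 0 over the next 40
    if step % 10 == 9:
        print(step + 1, optimizer.param_groups[0]["lr"])

Note that the decay phase starts exactly at step warmup_steps (the milestone), and once decay_steps more steps have passed, LinearLR holds the LR at its final end_factor value for any remaining training steps.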