mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
linear scheduler
This commit is contained in:
parent
f3fca0c81b
commit
84f6e0da4e
@ -46,6 +46,7 @@ from model.cfm import CFM
|
||||
from model.dit import DiT
|
||||
from model.utils import convert_char_to_pinyin
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch.optim.lr_scheduler import LinearLR, SequentialLR
|
||||
from torch import Tensor
|
||||
|
||||
# from torch.cuda.amp import GradScaler
|
||||
@ -168,12 +169,12 @@ def get_parser():
|
||||
default="f5-tts/vocab.txt",
|
||||
help="Path to the unique text tokens file",
|
||||
)
|
||||
|
||||
# /home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
|
||||
parser.add_argument(
|
||||
"--pretrained-model-path",
|
||||
type=str,
|
||||
default="/home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt",
|
||||
help="Path to the unique text tokens file",
|
||||
default=None,
|
||||
help="Path to file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -199,6 +200,14 @@ def get_parser():
|
||||
decreases. We suggest not to change this.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decay-steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="""Number of steps that affects how rapidly the learning rate
|
||||
decreases. We suggest not to change this.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
@ -937,7 +946,16 @@ def run(rank, world_size, args):
|
||||
logging.info("About to create model")
|
||||
|
||||
model = get_model(params)
|
||||
# model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
|
||||
if params.pretrained_model_path:
|
||||
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
|
||||
if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint:
|
||||
model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
|
||||
else:
|
||||
_ = load_checkpoint(
|
||||
params.pretrained_model_path,
|
||||
model=model,
|
||||
)
|
||||
|
||||
model = model.to(device)
|
||||
|
||||
with open(f"{params.exp_dir}/model.txt", "w") as f:
|
||||
@ -990,7 +1008,12 @@ def run(rank, world_size, args):
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
|
||||
# scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
|
||||
warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
|
||||
decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
|
||||
scheduler = SequentialLR(
|
||||
optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
|
||||
)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user