Commit 84f6e0da4e: linear scheduler
Repository: https://github.com/k2-fsa/icefall.git
Parent commit: f3fca0c81b
@@ -46,6 +46,7 @@ from model.cfm import CFM
 from model.dit import DiT
 from model.utils import convert_char_to_pinyin
 from optim import Eden, ScaledAdam
+from torch.optim.lr_scheduler import LinearLR, SequentialLR
 from torch import Tensor
 
 # from torch.cuda.amp import GradScaler
@@ -168,12 +169,12 @@ def get_parser():
         default="f5-tts/vocab.txt",
         help="Path to the unique text tokens file",
     )
-
+    # /home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt
     parser.add_argument(
         "--pretrained-model-path",
         type=str,
-        default="/home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt",
-        help="Path to the unique text tokens file",
+        default=None,
+        help="Path to file",
     )
 
     parser.add_argument(
@@ -199,6 +200,14 @@ def get_parser():
         decreases. We suggest not to change this.""",
     )
 
+    parser.add_argument(
+        "--decay-steps",
+        type=int,
+        default=None,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
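Note that `--decay-steps` defaults to `None`, while the scheduler change further down passes it straight to `LinearLR` as `total_iters`, so the flag appears to require an explicit value when launching training. A minimal sketch of how the two step counts could be chosen; every number here is illustrative, not a recipe default:

# Illustrative only: deriving warmup/decay step counts for this recipe.
num_epochs = 10           # hypothetical training length
steps_per_epoch = 60_000  # hypothetical dataloader length
total_steps = num_epochs * steps_per_epoch

warmup_steps = 5_000                      # ramp up over the first 5k steps
decay_steps = total_steps - warmup_steps  # decay over everything that remains

print(f"--warmup-steps {warmup_steps} --decay-steps {decay_steps}")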
@@ -937,7 +946,16 @@ def run(rank, world_size, args):
     logging.info("About to create model")
 
     model = get_model(params)
-    # model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
+    if params.pretrained_model_path:
+        checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
+        if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
+            model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
+        else:
+            _ = load_checkpoint(
+                params.pretrained_model_path,
+                model=model,
+            )
+
     model = model.to(device)
 
     with open(f"{params.exp_dir}/model.txt", "w") as f:
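The new branch dispatches on the checkpoint layout: a checkpoint carrying an `ema_model_state_dict` or `model_state_dict` key is treated as an F5-TTS release and goes through `load_F5_TTS_pretrained_checkpoint`, while anything else is assumed to be an icefall checkpoint and handled by `load_checkpoint`. A self-contained sketch of the same dispatch; the key handling inside the F5-TTS branch is an assumption for illustration, not code copied from the recipe:

import torch

def load_pretrained(model: torch.nn.Module, path: str) -> torch.nn.Module:
    # Inspect the checkpoint once to decide which loader applies.
    checkpoint = torch.load(path, map_location="cpu")
    if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
        # F5-TTS layout: weights nested under an (EMA) state-dict key.
        state_dict = checkpoint.get(
            "ema_model_state_dict", checkpoint.get("model_state_dict")
        )
        # Hypothetical prefix cleanup; the real load_F5_TTS_pretrained_checkpoint
        # may rename or filter keys differently.
        state_dict = {k.removeprefix("ema_model."): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict, strict=False)
    else:
        # icefall layout: the weights live under the checkpoint's "model" entry.
        model.load_state_dict(checkpoint["model"])
    return model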
@@ -990,7 +1008,12 @@ def run(rank, world_size, args):
     else:
         raise NotImplementedError()
 
-    scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
+    # scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
+    warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
+    decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
+    scheduler = SequentialLR(
+        optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
+    )
 
     optimizer.zero_grad()
 
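The replacement schedule ramps the learning rate linearly from near zero up to the optimizer's base LR over `warmup_steps`, then decays it linearly back toward zero over `decay_steps`; `SequentialLR` switches from the first `LinearLR` to the second at the `warmup_steps` milestone. A self-contained sketch of the resulting curve, presumably stepped once per batch as the old Eden scheduler was; the optimizer and step counts below are placeholders:

import torch
from torch.optim.lr_scheduler import LinearLR, SequentialLR

# Throwaway parameter/optimizer; only the schedule shape matters here.
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)

warmup_steps, decay_steps = 100, 900  # illustrative, not recipe defaults
warmup = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(opt, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(opt, schedulers=[warmup, decay], milestones=[warmup_steps])

for step in range(warmup_steps + decay_steps):
    opt.step()
    scheduler.step()  # one scheduler step per optimizer step (i.e. per batch)
    if step in (0, warmup_steps - 1, warmup_steps + decay_steps - 1):
        print(step, scheduler.get_last_lr())  # ~0 at start, 1e-4 at peak, ~0 at end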