mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
remove cosyvoice scheduler
This commit is contained in:
parent 3f8e64b9dc
commit 29de94ee2a
@@ -1,7 +0,0 @@
-# from f5_tts.model.backbones.dit import DiT
-# from f5_tts.model.backbones.mmdit import MMDiT
-# from f5_tts.model.backbones.unett import UNetT
-# from f5_tts.model.cfm import CFM
-# from f5_tts.model.trainer import Trainer
-
-# __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
@@ -290,8 +290,6 @@ class CFM(nn.Module):
 
         # time step
         time = torch.rand((batch,), dtype=dtype, device=self.device)
-        # add cosyvoice cosine scheduler
-        time = 1 - torch.cos(time * 0.5 * torch.pi)
         # TODO. noise_scheduler
 
         # sample xt (φ_t(x) in the paper)
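
The two removed lines above swapped the plain uniform time-step draw for a cosine-warped one. As a rough standalone sketch (not the icefall or CosyVoice code itself; batch size, dtype, and device below are placeholder values), the warp 1 - cos(t * pi / 2) keeps the time step inside [0, 1] but skews the sampled values toward 0:

import torch

batch, dtype, device = 8, torch.float32, "cpu"  # illustrative values only

# kept after this commit: plain uniform time steps in [0, 1)
time = torch.rand((batch,), dtype=dtype, device=device)

# removed by this commit: cosine warp of the same draw; values stay in [0, 1]
# but concentrate toward 0, so small time steps are sampled more often
time_warped = 1 - torch.cos(time * 0.5 * torch.pi)

print(time.sort().values)
print(time_warped.sort().values)
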
@@ -998,24 +998,14 @@ def run(rank, world_size, args):
             weight_decay=1e-2,
             eps=1e-8,
         )
-    elif params.optimizer_name == "Adam":
-        optimizer = torch.optim.Adam(
-            model_parameters,
-            lr=params.base_lr,
-            betas=(0.9, 0.95),
-            eps=1e-8,
-        )
     else:
         raise NotImplementedError()
 
-    if params.decay_steps:
-        warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
-        decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
-        scheduler = SequentialLR(
-            optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
-        )
-        assert 1==2
-    else:
-        scheduler = Eden(optimizer, 50000, 10, warmup_batches=params.warmup_steps)
+    warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
+    decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
+    scheduler = SequentialLR(
+        optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
+    )
 
     optimizer.zero_grad()
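
For context, here is a minimal self-contained sketch of the warmup-then-decay schedule this hunk leaves in place: one LinearLR ramps the learning rate up over the warmup steps, a second LinearLR ramps it back down, and SequentialLR switches between them at the warmup boundary. The toy model, the AdamW stand-in optimizer, and the base_lr / warmup_steps / decay_steps values are assumptions for illustration, not the values the training script reads from params:

import torch
from torch.optim.lr_scheduler import LinearLR, SequentialLR

model = torch.nn.Linear(16, 16)       # stand-in for the real model
base_lr = 1e-4                        # assumed; the script uses params.base_lr
warmup_steps, decay_steps = 100, 900  # assumed; params.warmup_steps / params.decay_steps

# AdamW used here only as a stand-in for whichever optimizer the script builds
optimizer = torch.optim.AdamW(
    model.parameters(), lr=base_lr, weight_decay=1e-2, eps=1e-8
)

warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(
    optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
)

for step in range(warmup_steps + decay_steps):
    optimizer.zero_grad()
    # ... forward pass and loss.backward() would go here ...
    optimizer.step()   # step the optimizer first,
    scheduler.step()   # then advance the schedule once per training step
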