diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py b/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py
index 349c7220e..5a30694b4 100644
--- a/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py
+++ b/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py
@@ -290,6 +290,8 @@ class CFM(nn.Module):
 
         # time step
         time = torch.rand((batch,), dtype=dtype, device=self.device)
+        # CosyVoice-style cosine scheduler: bias sampled time steps toward 0 (noisier steps)
+        time = 1 - torch.cos(time * 0.5 * torch.pi)
         # TODO. noise_scheduler
 
         # sample xt (φ_t(x) in the paper)
diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py
index 35e1e8e83..4ba393af1 100755
--- a/egs/wenetspeech4tts/TTS/f5-tts/train.py
+++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py
@@ -1007,13 +1007,14 @@ def run(rank, world_size, args):
         )
     else:
         raise NotImplementedError()
-
-    # scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
-    warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
-    decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
-    scheduler = SequentialLR(
-        optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
-    )
+    if params.decay_steps:
+        warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
+        decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
+        scheduler = SequentialLR(
+            optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
+        )
+    else:
+        scheduler = Eden(optimizer, 50000, 10, warmup_batches=params.warmup_steps)
 
     optimizer.zero_grad()
 
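For reference, the cosine warp in the first hunk maps a uniform t in [0, 1) through 1 - cos(t * pi / 2), which stays in [0, 1) but concentrates samples near t = 0, the noisier end of the flow. A minimal standalone sketch of that warp (the sample count is illustrative, not from the recipe):

# Sketch of the cosine time-step warp added in cfm.py.
import torch

time = torch.rand(100_000)                     # uniform draws in [0, 1)
warped = 1 - torch.cos(time * 0.5 * torch.pi)  # still in [0, 1)

print(f"uniform mean: {time.mean():.3f}")    # ~0.500
print(f"warped mean:  {warped.mean():.3f}")  # ~0.363 = 1 - 2/pi; mass shifts toward 0

Because the warp's derivative vanishes at t = 0, a whole range of small uniform draws is squeezed into a tiny neighborhood of 0, so training sees early (high-noise) time steps more often.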
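The second hunk makes the LinearLR warmup + decay schedule conditional on params.decay_steps, falling back to icefall's Eden scheduler otherwise. A minimal sketch of the warmup/decay path, using a dummy parameter and hypothetical step counts rather than the recipe's real settings:

# Sketch of the two-phase LinearLR schedule guarded by params.decay_steps.
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

warmup_steps, decay_steps = 100, 900  # illustrative values
optimizer = AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)

warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

for step in range(warmup_steps + decay_steps):
    optimizer.step()
    scheduler.step()
    if step in (0, warmup_steps - 1, warmup_steps + decay_steps - 1):
        # LR ramps from ~0 up to 1e-3 over the warmup, then decays back toward 0.
        print(step, scheduler.get_last_lr())

Each LinearLR scales the optimizer's base lr by a factor interpolated between start_factor and end_factor, so the peak learning rate at the milestone is the optimizer's own lr.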