From 29de94ee2a66332ad692338b70667bfdab91c14e Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 20 Jan 2025 08:25:30 +0000
Subject: [PATCH] remove cosyvoice scheduler

---
 .../TTS/f5-tts/model/__init__.py            |  7 -------
 egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py |  2 --
 egs/wenetspeech4tts/TTS/f5-tts/train.py     | 22 ++++++----------------
 3 files changed, 6 insertions(+), 25 deletions(-)
 delete mode 100644 egs/wenetspeech4tts/TTS/f5-tts/model/__init__.py

diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/__init__.py b/egs/wenetspeech4tts/TTS/f5-tts/model/__init__.py
deleted file mode 100644
index 878b5c171..000000000
--- a/egs/wenetspeech4tts/TTS/f5-tts/model/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# from f5_tts.model.backbones.dit import DiT
-# from f5_tts.model.backbones.mmdit import MMDiT
-# from f5_tts.model.backbones.unett import UNetT
-# from f5_tts.model.cfm import CFM
-# from f5_tts.model.trainer import Trainer
-
-# __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py b/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py
index 5a30694b4..349c7220e 100644
--- a/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py
+++ b/egs/wenetspeech4tts/TTS/f5-tts/model/cfm.py
@@ -290,8 +290,6 @@ class CFM(nn.Module):
 
         # time step
         time = torch.rand((batch,), dtype=dtype, device=self.device)
-        # add cosyvoice cosine scheduler
-        time = 1 - torch.cos(time * 0.5 * torch.pi)
         # TODO. noise_scheduler
 
         # sample xt (φ_t(x) in the paper)
diff --git a/egs/wenetspeech4tts/TTS/f5-tts/train.py b/egs/wenetspeech4tts/TTS/f5-tts/train.py
index 4ba393af1..6e400f10c 100755
--- a/egs/wenetspeech4tts/TTS/f5-tts/train.py
+++ b/egs/wenetspeech4tts/TTS/f5-tts/train.py
@@ -998,24 +998,14 @@ def run(rank, world_size, args):
             weight_decay=1e-2,
             eps=1e-8,
         )
-    elif params.optimizer_name == "Adam":
-        optimizer = torch.optim.Adam(
-            model_parameters,
-            lr=params.base_lr,
-            betas=(0.9, 0.95),
-            eps=1e-8,
-        )
     else:
         raise NotImplementedError()
-    if params.decay_steps:
-        warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
-        decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
-        scheduler = SequentialLR(
-            optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
-        )
-        assert 1==2
-    else:
-        scheduler = Eden(optimizer, 50000, 10, warmup_batches=params.warmup_steps)
+
+    warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
+    decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
+    scheduler = SequentialLR(
+        optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
+    )
 
     optimizer.zero_grad()
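
For context on the cfm.py hunk: the deleted lines remapped the uniformly sampled
flow-matching time step through a CosyVoice-style cosine warp. A minimal sketch
of the two behaviors, using standalone tensors rather than the CFM class
(`batch` is a stand-in for the batch size used in cfm.py):

# Sketch: effect of the removed cosine warp on the sampled time step.
import torch

batch = 4
time = torch.rand((batch,))  # uniform in [0, 1) -- the behavior this patch keeps

# The removed lines applied the warp below, which concentrates sampled
# time steps near 0, since 1 - cos(pi*t/2) grows only quadratically for small t.
warped = 1 - torch.cos(time * 0.5 * torch.pi)
print(time)
print(warped)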
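
For the train.py hunk: the scheduler is now built unconditionally as a linear
warmup followed by a linear decay, chained with SequentialLR (the Eden fallback
and the stray `assert 1==2` are gone). A runnable sketch of that schedule, with
placeholder values standing in for params.base_lr, params.warmup_steps, and
params.decay_steps:

# Sketch of the warmup + decay schedule train.py now always uses.
import torch
from torch.optim.lr_scheduler import LinearLR, SequentialLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # base_lr placeholder

warmup_steps, decay_steps = 100, 1000  # placeholders for the params values
warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps])

for step in range(warmup_steps + decay_steps):
    optimizer.step()
    scheduler.step()
    # LR ramps linearly from ~0 to base_lr over warmup_steps,
    # then decays linearly back toward ~0 over decay_steps.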