From 58f656282424c568bcb6c543fa6e81c75a6303c3 Mon Sep 17 00:00:00 2001
From: JinZr
Date: Sun, 6 Oct 2024 19:07:07 +0800
Subject: [PATCH] added scheduler w/ warmup

---
 egs/libritts/CODEC/encodec/encodec.py   |   2 +-
 egs/libritts/CODEC/encodec/loss.py      |  43 +++++---
 egs/libritts/CODEC/encodec/scheduler.py | 141 ++++++++++++++++++++++++
 egs/libritts/CODEC/encodec/train.py     |  61 ++++++----
 4 files changed, 209 insertions(+), 38 deletions(-)
 create mode 100644 egs/libritts/CODEC/encodec/scheduler.py

diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py
index 725ce5d01..aa0373bfa 100644
--- a/egs/libritts/CODEC/encodec/encodec.py
+++ b/egs/libritts/CODEC/encodec/encodec.py
@@ -60,7 +60,7 @@ class Encodec(nn.Module):
         self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
             average_by_discriminators=True, loss_type="hinge"
         )
-        self.feature_match_loss = FeatureLoss(average_by_layers=False)
+        self.feature_match_loss = FeatureLoss()
         self.wav_reconstruction_loss = WavReconstructionLoss()
         self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
             sampling_rate=self.sampling_rate
diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py
index f4188a313..a4e0ec06d 100644
--- a/egs/libritts/CODEC/encodec/loss.py
+++ b/egs/libritts/CODEC/encodec/loss.py
@@ -45,8 +45,8 @@ class GeneratorAdversarialLoss(torch.nn.Module):
             Tensor: Generator adversarial loss value.

         """
+        adv_loss = 0.0
         if isinstance(outputs, (tuple, list)):
-            adv_loss = 0.0
             for i, outputs_ in enumerate(outputs):
                 if isinstance(outputs_, (tuple, list)):
                     # NOTE(kan-bayashi): case including feature maps
@@ -55,9 +55,10 @@ class GeneratorAdversarialLoss(torch.nn.Module):
             if self.average_by_discriminators:
                 adv_loss /= i + 1
         else:
-            adv_loss = self.criterion(outputs)
-
-        return adv_loss / len(outputs)
+            for i, outputs_ in enumerate(outputs):
+                adv_loss += self.criterion(outputs_)
+            adv_loss /= i + 1
+        return adv_loss

     def _mse_loss(self, x):
         return F.mse_loss(x, x.new_ones(x.size()))
@@ -112,9 +113,9 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
             Tensor: Discriminator fake loss value.
""" + real_loss = 0.0 + fake_loss = 0.0 if isinstance(outputs, (tuple, list)): - real_loss = 0.0 - fake_loss = 0.0 for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): if isinstance(outputs_hat_, (tuple, list)): # NOTE(kan-bayashi): case including feature maps @@ -126,10 +127,13 @@ class DiscriminatorAdversarialLoss(torch.nn.Module): fake_loss /= i + 1 real_loss /= i + 1 else: - real_loss = self.real_criterion(outputs) - fake_loss = self.fake_criterion(outputs_hat) + for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + fake_loss /= i + 1 + real_loss /= i + 1 - return real_loss / len(outputs), fake_loss / len(outputs) + return real_loss, fake_loss def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: return F.mse_loss(x, x.new_ones(x.size())) @@ -204,7 +208,7 @@ class FeatureLoss(torch.nn.Module): if self.average_by_discriminators: feat_match_loss /= i + 1 - return feat_match_loss / (len(feats) * len(feats[0])) + return feat_match_loss class MelSpectrogramReconstructionLoss(torch.nn.Module): @@ -233,7 +237,7 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module): self.wav_to_specs.append( MelSpectrogram( sample_rate=sampling_rate, - n_fft=s, + n_fft=max(s, 512), win_length=s, hop_length=s // 4, n_mels=n_mels, @@ -462,8 +466,15 @@ def loss_g( if __name__ == "__main__": - la = FeatureLoss(average_by_layers=False, average_by_discriminators=False) - aa = [torch.rand(192, 192) for _ in range(3)] - bb = [torch.rand(192, 192) for _ in range(3)] - print(la(bb, aa)) - print(feature_loss(aa, bb)) + # la = FeatureLoss(average_by_layers=True, average_by_discriminators=True) + # aa = [torch.rand(192, 192) for _ in range(3)] + # bb = [torch.rand(192, 192) for _ in range(3)] + # print(la(bb, aa)) + # print(feature_loss(aa, bb)) + la = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge") + aa = torch.Tensor([0.1, 0.2, 0.3, 0.4]) + bb = torch.Tensor([0.4, 0.3, 0.2, 0.1]) + print(la(aa)) + print(adversarial_g_loss(aa)) + print(la(bb)) + print(adversarial_g_loss(bb)) diff --git a/egs/libritts/CODEC/encodec/scheduler.py b/egs/libritts/CODEC/encodec/scheduler.py new file mode 100644 index 000000000..1a62e96f2 --- /dev/null +++ b/egs/libritts/CODEC/encodec/scheduler.py @@ -0,0 +1,141 @@ +import math +from bisect import bisect_right + +from torch.optim.lr_scheduler import _LRScheduler + + +class WarmupLrScheduler(_LRScheduler): + def __init__( + self, + optimizer, + warmup_epoch=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ): + self.warmup_epoch = warmup_epoch + self.warmup_ratio = warmup_ratio + self.warmup = warmup + super(WarmupLrScheduler, self).__init__(optimizer, last_epoch) + + def get_lr(self): + ratio = self.get_lr_ratio() + lrs = [ratio * lr for lr in self.base_lrs] + return lrs + + def get_lr_ratio(self): + if self.last_epoch < self.warmup_epoch: + ratio = self.get_warmup_ratio() + else: + ratio = self.get_main_ratio() + return ratio + + def get_main_ratio(self): + raise NotImplementedError + + def get_warmup_ratio(self): + assert self.warmup in ("linear", "exp") + alpha = self.last_epoch / self.warmup_epoch + if self.warmup == "linear": + ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha + elif self.warmup == "exp": + ratio = self.warmup_ratio ** (1.0 - alpha) + return ratio + + +class WarmupPolyLrScheduler(WarmupLrScheduler): + def __init__( + self, + optimizer, + power, + max_epoch, + warmup_epoch=500, + 
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.power = power
+        self.max_epoch = max_epoch
+        super(WarmupPolyLrScheduler, self).__init__(
+            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+        )
+
+    def get_main_ratio(self):
+        real_epoch = self.last_epoch - self.warmup_epoch
+        real_max_epoch = self.max_epoch - self.warmup_epoch
+        alpha = real_epoch / real_max_epoch
+        ratio = (1 - alpha) ** self.power
+        return ratio
+
+
+class WarmupExpLrScheduler(WarmupLrScheduler):
+    def __init__(
+        self,
+        optimizer,
+        gamma,
+        interval=1,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.gamma = gamma
+        self.interval = interval
+        super(WarmupExpLrScheduler, self).__init__(
+            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+        )
+
+    def get_main_ratio(self):
+        real_epoch = self.last_epoch - self.warmup_epoch
+        ratio = self.gamma ** (real_epoch // self.interval)
+        return ratio
+
+
+class WarmupCosineLrScheduler(WarmupLrScheduler):
+    def __init__(
+        self,
+        optimizer,
+        max_epoch,
+        eta_ratio=0,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.eta_ratio = eta_ratio
+        self.max_epoch = max_epoch
+        super(WarmupCosineLrScheduler, self).__init__(
+            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+        )
+
+    def get_main_ratio(self):
+        real_max_epoch = self.max_epoch - self.warmup_epoch
+        return (
+            self.eta_ratio
+            + (1 - self.eta_ratio)
+            * (1 + math.cos(math.pi * self.last_epoch / real_max_epoch))
+            / 2
+        )
+
+
+class WarmupStepLrScheduler(WarmupLrScheduler):
+    def __init__(
+        self,
+        optimizer,
+        milestones: list,
+        gamma=0.1,
+        warmup_epoch=500,
+        warmup_ratio=5e-4,
+        warmup="exp",
+        last_epoch=-1,
+    ):
+        self.milestones = milestones
+        self.gamma = gamma
+        super(WarmupStepLrScheduler, self).__init__(
+            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+        )
+
+    def get_main_ratio(self):
+        real_epoch = self.last_epoch - self.warmup_epoch
+        ratio = self.gamma ** bisect_right(self.milestones, real_epoch)
+        return ratio
diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py
index 5b21c81dd..0c761b8ed 100644
--- a/egs/libritts/CODEC/encodec/train.py
+++ b/egs/libritts/CODEC/encodec/train.py
@@ -16,6 +16,7 @@ from encodec import Encodec
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
 from loss import adopt_weight
+from scheduler import WarmupCosineLrScheduler
 from torch import nn
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -188,10 +189,10 @@ def get_params() -> AttributeDict:
             "sampling_rate": 24000,
             "chunk_size": 1.0,  # in seconds
             "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
-            "lambda_wav": 0.1,  # loss scaling coefficient for waveform loss
+            "lambda_wav": 1.0,  # loss scaling coefficient for waveform loss
             "lambda_feat": 3.0,  # loss scaling coefficient for feat loss
             "lambda_rec": 1.0,  # loss scaling coefficient for reconstruction loss
-            "lambda_com": 1.0,  # loss scaling coefficient for commitment loss
+            "lambda_com": 100.0,  # loss scaling coefficient for commitment loss
         }
     )

@@ -260,7 +261,7 @@ def get_model(params: AttributeDict) -> nn.Module:
     #     }
     #     discriminator_params = {
     #         "stft_discriminator_n_filters": 32,
-    #         "discriminator_iter_start": 500,
+    #         "discriminator_epoch_start": 5,
     #     }
     #     inference_params = {
     #         "target_bw": 7.5,
@@ -275,7 +276,10 @@ def get_model(params: AttributeDict) -> nn.Module:
     }
     discriminator_params = {
         "stft_discriminator_n_filters": 32,
- "discriminator_iter_start": 500, + "discriminator_epoch_start": 3, + "n_ffts": [1024, 2048, 512], + "hop_lengths": [256, 512, 128], + "win_lengths": [1024, 2048, 512], } inference_params = { "target_bw": 12, @@ -316,7 +320,10 @@ def get_model(params: AttributeDict) -> nn.Module: multi_scale_discriminator=None, multi_period_discriminator=None, multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator( - n_filters=params.stft_discriminator_n_filters + n_filters=params.stft_discriminator_n_filters, + n_ffts=params.n_ffts, + hop_lengths=params.hop_lengths, + win_lengths=params.win_lengths, ), ) return model @@ -437,8 +444,8 @@ def train_one_epoch( with autocast(enabled=params.use_fp16): d_weight = adopt_weight( params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, + params.cur_epoch, + threshold=params.discriminator_epoch_start, ) # forward discriminator ( @@ -473,8 +480,8 @@ def train_one_epoch( with autocast(enabled=params.use_fp16): g_weight = adopt_weight( params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, + params.cur_epoch, + threshold=params.discriminator_epoch_start, ) # forward generator ( @@ -507,7 +514,7 @@ def train_one_epoch( gen_loss = ( gen_adv_loss + reconstruction_loss - + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss + + params.lambda_feat * feature_loss + params.lambda_com * commit_loss ) for k, v in stats_g.items(): @@ -688,8 +695,8 @@ def compute_validation_loss( d_weight = adopt_weight( params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, + params.cur_epoch, + threshold=params.discriminator_epoch_start, ) # forward discriminator @@ -721,8 +728,8 @@ def compute_validation_loss( g_weight = adopt_weight( params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, + params.cur_epoch, + threshold=params.discriminator_epoch_start, ) # forward generator ( @@ -753,7 +760,7 @@ def compute_validation_loss( gen_loss = ( gen_adv_loss + reconstruction_loss - + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss + + params.lambda_feat * feature_loss + params.lambda_com * commit_loss ) assert gen_loss.requires_grad is False @@ -831,8 +838,8 @@ def scan_pessimistic_batches_for_oom( + disc_scale_fake_adv_loss ) * adopt_weight( params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, + params.cur_epoch, + threshold=params.discriminator_train_start, ) optimizer_d.zero_grad() loss_d.backward() @@ -859,8 +866,8 @@ def scan_pessimistic_batches_for_oom( (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) * adopt_weight( params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, + 0, + threshold=params.discriminator_epoch_start, ) + ( params.lambda_wav * wav_reconstruction_loss @@ -1000,8 +1007,20 @@ def run(rank, world_size, args): betas=(0.5, 0.9), ) - scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) - scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + scheduler_g = WarmupCosineLrScheduler( + optimizer=optimizer_g, + max_epoch=params.num_epochs, + eta_ratio=0.1, + warmup_epoch=params.discriminator_epoch_start, + warmup_ratio=1e-4, + ) + scheduler_d = WarmupCosineLrScheduler( + optimizer=optimizer_d, + max_epoch=params.num_epochs, + eta_ratio=0.1, + warmup_epoch=params.discriminator_epoch_start, + warmup_ratio=1e-4, + ) if checkpoints is not None: # load state_dict for optimizers