diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py
index a4e0ec06d..ae1e34bdd 100644
--- a/egs/libritts/CODEC/encodec/loss.py
+++ b/egs/libritts/CODEC/encodec/loss.py
@@ -142,10 +142,10 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
         return F.mse_loss(x, x.new_zeros(x.size()))

     def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
-        return F.relu(x.new_ones(x.size()) - x).mean()
+        return F.relu(torch.ones_like(x) - x).mean()

     def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
-        return F.relu(x.new_ones(x.size()) + x).mean()
+        return F.relu(torch.ones_like(x) + x).mean()


 class FeatureLoss(torch.nn.Module):
@@ -200,7 +200,7 @@ class FeatureLoss(torch.nn.Module):
                 feats_ = feats_[:-1]
             for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
                 feat_match_loss_ += (
-                    (feat_hat_ - feat_).abs() / (feat_.abs().mean())
+                    F.l1_loss(feat_hat_, feat_.detach()) / (feat_.detach().abs().mean())
                 ).mean()
             if self.average_by_layers:
                 feat_match_loss_ /= j + 1
@@ -272,9 +272,16 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
             mel_hat = wav_to_spec(x_hat.squeeze(1))
             mel = wav_to_spec(x.squeeze(1))

-            mel_loss += F.l1_loss(
-                mel_hat, mel, reduce=True, reduction="mean"
-            ) + F.mse_loss(mel_hat, mel, reduce=True, reduction="mean")
+            mel_loss += (
+                F.l1_loss(mel_hat, mel, reduce=True, reduction="mean")
+                + (
+                    (
+                        (torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7))
+                        ** 2
+                    ).mean(dim=-2)
+                    ** 0.5
+                ).mean()
+            )

         # mel_hat = self.wav_to_spec(x_hat.squeeze(1))
         # mel = self.wav_to_spec(x.squeeze(1))
@@ -307,7 +314,7 @@ class WavReconstructionLoss(torch.nn.Module):
             Tensor: Wav loss value.

         """
-        wav_loss = F.l1_loss(x, x_hat, reduce=True, reduction="mean")
+        wav_loss = F.l1_loss(x, x_hat)

         return wav_loss

diff --git a/egs/libritts/CODEC/encodec/scheduler.py b/egs/libritts/CODEC/encodec/scheduler.py
index 1a62e96f2..fb6ba087d 100644
--- a/egs/libritts/CODEC/encodec/scheduler.py
+++ b/egs/libritts/CODEC/encodec/scheduler.py
@@ -4,16 +4,40 @@ from bisect import bisect_right
 from torch.optim.lr_scheduler import _LRScheduler


+# It will be replaced with the Hugging Face optimization module.
+class WarmUpLR(_LRScheduler):
+    """Warmup learning rate scheduler.
+    Args:
+        optimizer: optimizer (e.g. SGD)
+        total_iters: total number of iterations of the warmup phase
+    """
+
+    def __init__(self, optimizer, iter_per_epoch, warmup_epoch, last_epoch=-1):
+
+        self.total_iters = iter_per_epoch * warmup_epoch
+        self.iter_per_epoch = iter_per_epoch
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        """We use the first m batches and set the learning
+        rate to base_lr * m / total_iters.
+        """
+        return [
+            base_lr * self.last_epoch / (self.total_iters + 1e-8)
+            for base_lr in self.base_lrs
+        ]
+
+
 class WarmupLrScheduler(_LRScheduler):
     def __init__(
         self,
         optimizer,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
-        self.warmup_epoch = warmup_epoch
+        self.warmup_iter = warmup_iter
         self.warmup_ratio = warmup_ratio
         self.warmup = warmup
         super(WarmupLrScheduler, self).__init__(optimizer, last_epoch)
@@ -24,7 +48,7 @@ class WarmupLrScheduler(_LRScheduler):
         return lrs

     def get_lr_ratio(self):
-        if self.last_epoch < self.warmup_epoch:
+        if self.last_epoch < self.warmup_iter:
             ratio = self.get_warmup_ratio()
         else:
             ratio = self.get_main_ratio()
@@ -35,7 +59,7 @@

     def get_warmup_ratio(self):
         assert self.warmup in ("linear", "exp")
-        alpha = self.last_epoch / self.warmup_epoch
+        alpha = self.last_epoch / self.warmup_iter
         if self.warmup == "linear":
             ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha
         elif self.warmup == "exp":
@@ -48,22 +72,22 @@ class WarmupPolyLrScheduler(WarmupLrScheduler):
         self,
         optimizer,
         power,
-        max_epoch,
-        warmup_epoch=500,
+        max_iter,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
         self.power = power
-        self.max_epoch = max_epoch
+        self.max_iter = max_iter
         super(WarmupPolyLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )

     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        real_max_epoch = self.max_epoch - self.warmup_epoch
-        alpha = real_epoch / real_max_epoch
+        real_iter = self.last_epoch - self.warmup_iter
+        real_max_iter = self.max_iter - self.warmup_iter
+        alpha = real_iter / real_max_iter
         ratio = (1 - alpha) ** self.power
         return ratio

@@ -74,7 +98,7 @@ class WarmupExpLrScheduler(WarmupLrScheduler):
         optimizer,
         gamma,
         interval=1,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
@@ -82,12 +106,12 @@
         self.gamma = gamma
         self.interval = interval
         super(WarmupExpLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )

     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        ratio = self.gamma ** (real_epoch // self.interval)
+        real_iter = self.last_epoch - self.warmup_iter
+        ratio = self.gamma ** (real_iter // self.interval)
         return ratio


@@ -95,25 +119,26 @@ class WarmupCosineLrScheduler(WarmupLrScheduler):
     def __init__(
         self,
         optimizer,
-        max_epoch,
+        max_iter,
         eta_ratio=0,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
     ):
         self.eta_ratio = eta_ratio
-        self.max_epoch = max_epoch
+        self.max_iter = max_iter
         super(WarmupCosineLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )

     def get_main_ratio(self):
-        real_max_epoch = self.max_epoch - self.warmup_epoch
+        real_iter = self.last_epoch - self.warmup_iter
+        real_max_iter = self.max_iter - self.warmup_iter
         return (
             self.eta_ratio
             + (1 - self.eta_ratio)
-            * (1 + math.cos(math.pi * self.last_epoch / real_max_epoch))
+            * (1 + math.cos(math.pi * self.last_epoch / real_max_iter))
             / 2
         )

@@ -124,7 +149,7 @@ class WarmupStepLrScheduler(WarmupLrScheduler):
         optimizer,
         milestones: list,
         gamma=0.1,
-        warmup_epoch=500,
+        warmup_iter=500,
         warmup_ratio=5e-4,
         warmup="exp",
         last_epoch=-1,
@@ -132,10 +157,10 @@
         self.milestones = milestones
         self.gamma = gamma
         super(WarmupStepLrScheduler, self).__init__(
-            optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch
+            optimizer, warmup_iter, warmup_ratio, warmup, last_epoch
         )

     def get_main_ratio(self):
-        real_epoch = self.last_epoch - self.warmup_epoch
-        ratio = self.gamma ** bisect_right(self.milestones, real_epoch)
+        real_iter = self.last_epoch - self.warmup_iter
+        ratio = self.gamma ** bisect_right(self.milestones, real_iter)
         return ratio
diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py
index 0c761b8ed..088dbc577 100644
--- a/egs/libritts/CODEC/encodec/train.py
+++ b/egs/libritts/CODEC/encodec/train.py
@@ -187,6 +187,7 @@ def get_params() -> AttributeDict:
             "valid_interval": 200,
             "env_info": get_env_info(),
             "sampling_rate": 24000,
+            "audio_normalization": False,
             "chunk_size": 1.0,  # in seconds
             "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
             "lambda_wav": 1.0,  # loss scaling coefficient for waveform loss
@@ -276,13 +277,13 @@ def get_model(params: AttributeDict) -> nn.Module:
     }
     discriminator_params = {
         "stft_discriminator_n_filters": 32,
-        "discriminator_epoch_start": 3,
+        "discriminator_epoch_start": 5,
         "n_ffts": [1024, 2048, 512],
         "hop_lengths": [256, 512, 128],
         "win_lengths": [1024, 2048, 512],
     }
     inference_params = {
-        "target_bw": 12,
+        "target_bw": 6,
     }

     params.update(generator_params)
@@ -353,6 +354,11 @@ def prepare_input(
             :, params.sampling_rate : params.sampling_rate + params.sampling_rate
         ]

+    if params.audio_normalization:
+        mean = audio.mean(dim=-1, keepdim=True)
+        std = audio.std(dim=-1, keepdim=True)
+        audio = (audio - mean) / (std + 1e-7)
+
     return audio, audio_lens, features, features_lens


@@ -532,6 +538,10 @@ def train_one_epoch(
                 save_bad_model()
                 raise

+        # step per iteration
+        scheduler_g.step()
+        scheduler_d.step()
+
         if params.print_diagnostics and batch_idx == 5:
             return

@@ -1009,16 +1019,16 @@ def run(rank, world_size, args):

     scheduler_g = WarmupCosineLrScheduler(
         optimizer=optimizer_g,
-        max_epoch=params.num_epochs,
+        max_iter=params.num_epochs * 1500,
         eta_ratio=0.1,
-        warmup_epoch=params.discriminator_epoch_start,
+        warmup_iter=params.discriminator_epoch_start * 1500,
         warmup_ratio=1e-4,
     )
     scheduler_d = WarmupCosineLrScheduler(
         optimizer=optimizer_d,
-        max_epoch=params.num_epochs,
+        max_iter=params.num_epochs * 1500,
         eta_ratio=0.1,
-        warmup_epoch=params.discriminator_epoch_start,
+        warmup_iter=params.discriminator_epoch_start * 1500,
         warmup_ratio=1e-4,
     )

@@ -1128,10 +1138,6 @@ def run(rank, world_size, args):
             best_valid_filename = params.exp_dir / "best-valid-loss.pt"
             copyfile(src=filename, dst=best_valid_filename)

-        # step per epoch
-        scheduler_g.step()
-        scheduler_d.step()
-
     logging.info("Done!")

     if world_size > 1:
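
Note (not part of the patch): the mel-reconstruction hunk above replaces the old L1 + MSE combination with L1 on the mel spectrogram plus an RMSE computed on log-magnitudes, averaged over the mel-bin axis. Below is a minimal stand-alone sketch of that term, assuming mel tensors shaped (batch, n_mels, frames); the function name and toy shapes are illustrative only, not the recipe's code.

# Illustrative sketch of the reworked mel term; not the recipe's code.
import torch
import torch.nn.functional as F


def mel_reconstruction_term(mel_hat: torch.Tensor, mel: torch.Tensor) -> torch.Tensor:
    """mel_hat, mel: (batch, n_mels, frames)."""
    l1 = F.l1_loss(mel_hat, mel)
    # RMSE over the mel-bin axis of the log-magnitude difference, then averaged.
    log_diff = torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7)
    log_rmse = ((log_diff**2).mean(dim=-2) ** 0.5).mean()
    return l1 + log_rmse


mel = torch.rand(2, 80, 100)
mel_hat = torch.rand(2, 80, 100)
print(mel_reconstruction_term(mel_hat, mel))  # scalar loss value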
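
Note (not part of the patch): with this change the schedulers are stepped once per optimizer update in train_one_epoch rather than once per epoch in run, and max_iter/warmup_iter are derived from the hard-coded estimate of 1500 iterations per epoch used in run above. A minimal usage sketch follows, assuming the recipe's scheduler.py is importable as `scheduler`; the toy model, optimizer, and hyper-parameters are placeholders, not the recipe's real generator setup.

# Illustrative sketch of per-iteration scheduler stepping; placeholders throughout.
import torch

from scheduler import WarmupCosineLrScheduler  # assumes the recipe directory is on sys.path

iters_per_epoch = 1500  # mirrors the estimate hard-coded in run() by this patch
num_epochs = 20
warmup_epochs = 5  # mirrors discriminator_epoch_start above

model = torch.nn.Linear(8, 8)  # stand-in for the generator
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
lr_scheduler = WarmupCosineLrScheduler(
    optimizer=optimizer,
    max_iter=num_epochs * iters_per_epoch,
    eta_ratio=0.1,
    warmup_iter=warmup_epochs * iters_per_epoch,
    warmup_ratio=1e-4,
)

for epoch in range(num_epochs):
    for batch_idx in range(iters_per_epoch):
        optimizer.step()  # forward/backward elided
        lr_scheduler.step()  # stepped per iteration, no longer per epoch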