From e788bb4853c7455020a829b1476426aeb189bc11 Mon Sep 17 00:00:00 2001 From: JinZr Date: Sun, 6 Oct 2024 13:38:05 +0800 Subject: [PATCH] making MSD and MPD optional --- egs/libritts/CODEC/encodec/encodec.py | 100 +++++++++++++++--------- egs/libritts/CODEC/encodec/train.py | 108 ++++++++++++-------------- 2 files changed, 115 insertions(+), 93 deletions(-) diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index 470142392..a2e540dcd 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -1,6 +1,6 @@ import math import random -from typing import List +from typing import List, Optional import numpy as np import torch @@ -25,8 +25,8 @@ class Encodec(nn.Module): quantizer: nn.Module, decoder: nn.Module, multi_scale_discriminator: nn.Module, - multi_period_discriminator: nn.Module, - multi_scale_stft_discriminator: nn.Module, + multi_period_discriminator: Optional[nn.Module] = None, + multi_scale_stft_discriminator: Optional[nn.Module] = None, cache_generator_outputs: bool = False, ): super(Encodec, self).__init__() @@ -113,28 +113,42 @@ class Encodec(nn.Module): with torch.no_grad(): # do not store discriminator gradient in generator turn y, fmap = self.multi_scale_stft_discriminator(speech.contiguous()) - y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( - speech.contiguous(), - speech_hat.contiguous(), - ) - y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( - speech.contiguous(), - speech_hat.contiguous(), - ) + + gen_period_adv_loss = torch.tensor(0.0) + feature_period_loss = torch.tensor(0.0) + if self.multi_period_discriminator is not None: + y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( + speech.contiguous(), + speech_hat.contiguous(), + ) + + gen_scale_adv_loss = torch.tensor(0.0) + feature_scale_loss = torch.tensor(0.0) + if self.multi_scale_discriminator is not None: + y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( + speech.contiguous(), + speech_hat.contiguous(), + ) # calculate losses with autocast(enabled=False): gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat) - gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat) - gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat) + + if self.multi_period_discriminator is not None: + gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat) + if self.multi_scale_discriminator is not None: + gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat) feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat) - feature_period_loss = self.feature_match_loss( - feats=fmap_p, feats_hat=fmap_p_hat - ) - feature_scale_loss = self.feature_match_loss( - feats=fmap_s, feats_hat=fmap_s_hat - ) + + if self.multi_period_discriminator is not None: + feature_period_loss = self.feature_match_loss( + feats=fmap_p, feats_hat=fmap_p_hat + ) + if self.multi_scale_discriminator is not None: + feature_scale_loss = self.feature_match_loss( + feats=fmap_s, feats_hat=fmap_s_hat + ) wav_reconstruction_loss = self.wav_reconstruction_loss( x=speech, x_hat=speech_hat @@ -245,28 +259,44 @@ class Encodec(nn.Module): y_hat, fmap_hat = self.multi_scale_stft_discriminator( speech_hat.contiguous().detach() ) - y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( - speech.contiguous(), - speech_hat.contiguous().detach(), - ) - y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( - speech.contiguous(), - speech_hat.contiguous().detach(), - ) + + disc_period_real_adv_loss, disc_period_fake_adv_loss = torch.tensor( + 0.0 + ), torch.tensor(0.0) + if self.multi_period_discriminator is not None: + y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( + speech.contiguous(), + speech_hat.contiguous().detach(), + ) + + disc_scale_real_adv_loss, disc_scale_fake_adv_loss = torch.tensor( + 0.0 + ), torch.tensor(0.0) + if self.multi_scale_discriminator is not None: + y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( + speech.contiguous(), + speech_hat.contiguous().detach(), + ) # calculate losses with autocast(enabled=False): ( disc_stft_real_adv_loss, disc_stft_fake_adv_loss, ) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat) - ( - disc_period_real_adv_loss, - disc_period_fake_adv_loss, - ) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat) - ( - disc_scale_real_adv_loss, - disc_scale_fake_adv_loss, - ) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat) + if self.multi_period_discriminator is not None: + ( + disc_period_real_adv_loss, + disc_period_fake_adv_loss, + ) = self.discriminator_adversarial_loss( + outputs=y_p, outputs_hat=y_p_hat + ) + if self.multi_scale_discriminator is not None: + ( + disc_scale_real_adv_loss, + disc_scale_fake_adv_loss, + ) = self.discriminator_adversarial_loss( + outputs=y_s, outputs_hat=y_s_hat + ) stats = dict( discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(), diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 206a72a76..0adffb658 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -313,8 +313,8 @@ def get_model(params: AttributeDict) -> nn.Module: encoder=encoder, quantizer=quantizer, decoder=decoder, - multi_scale_discriminator=MultiScaleDiscriminator(), - multi_period_discriminator=MultiPeriodDiscriminator(), + multi_scale_discriminator=None, + multi_period_discriminator=None, multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator( n_filters=params.stft_discriminator_n_filters ), @@ -456,17 +456,13 @@ def train_one_epoch( forward_generator=False, ) disc_loss = ( - ( - disc_stft_real_adv_loss - + disc_stft_fake_adv_loss - + disc_period_real_adv_loss - + disc_period_fake_adv_loss - + disc_scale_real_adv_loss - + disc_scale_fake_adv_loss - ) - * d_weight - / 3 - ) + disc_stft_real_adv_loss + + disc_stft_fake_adv_loss + + disc_period_real_adv_loss + + disc_period_fake_adv_loss + + disc_scale_real_adv_loss + + disc_scale_fake_adv_loss + ) * d_weight for k, v in stats_d.items(): loss_info[k] = v * batch_size # update discriminator @@ -499,13 +495,11 @@ def train_one_epoch( return_sample=params.batch_idx_train % params.log_interval == 0, ) gen_adv_loss = ( - (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) - * g_weight - / 3 - ) + gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss + ) * g_weight feature_loss = ( feature_stft_loss + feature_period_loss + feature_scale_loss - ) / 3 + ) reconstruction_loss = ( params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss @@ -714,17 +708,13 @@ def compute_validation_loss( forward_generator=False, ) disc_loss = ( - ( - disc_stft_real_adv_loss - + disc_stft_fake_adv_loss - + disc_period_real_adv_loss - + disc_period_fake_adv_loss - + disc_scale_real_adv_loss - + disc_scale_fake_adv_loss - ) - * d_weight - / 3 - ) + disc_stft_real_adv_loss + + disc_stft_fake_adv_loss + + disc_period_real_adv_loss + + disc_period_fake_adv_loss + + disc_scale_real_adv_loss + + disc_scale_fake_adv_loss + ) * d_weight assert disc_loss.requires_grad is False for k, v in stats_d.items(): loss_info[k] = v * batch_size @@ -753,13 +743,9 @@ def compute_validation_loss( return_sample=False, ) gen_adv_loss = ( - (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) - * g_weight - / 3 - ) - feature_loss = ( - feature_stft_loss + feature_period_loss + feature_scale_loss - ) / 3 + gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss + ) * g_weight + feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss reconstruction_loss = ( params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss ) @@ -836,20 +822,16 @@ def scan_pessimistic_batches_for_oom( forward_generator=False, ) loss_d = ( - ( - disc_stft_real_adv_loss - + disc_stft_fake_adv_loss - + disc_period_real_adv_loss - + disc_period_fake_adv_loss - + disc_scale_real_adv_loss - + disc_scale_fake_adv_loss - ) - * adopt_weight( - params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, - ) - / 3 + disc_stft_real_adv_loss + + disc_stft_fake_adv_loss + + disc_period_real_adv_loss + + disc_period_fake_adv_loss + + disc_scale_real_adv_loss + + disc_scale_fake_adv_loss + ) * adopt_weight( + params.lambda_adv, + params.batch_idx_train, + threshold=params.discriminator_iter_start, ) optimizer_d.zero_grad() loss_d.backward() @@ -879,7 +861,6 @@ def scan_pessimistic_batches_for_oom( params.batch_idx_train, threshold=params.discriminator_iter_start, ) - / 3 + params.lambda_rec * ( params.lambda_wav * wav_reconstruction_loss @@ -962,9 +943,17 @@ def run(rank, world_size, args): logging.info(f"Number of parameters in decoder: {num_param_d}") num_param_q = sum([p.numel() for p in quantizer.parameters()]) logging.info(f"Number of parameters in quantizer: {num_param_q}") - num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()]) + num_param_ds = ( + sum([p.numel() for p in multi_scale_discriminator.parameters()]) + if multi_scale_discriminator is not None + else 0 + ) logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}") - num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()]) + num_param_dp = ( + sum([p.numel() for p in multi_period_discriminator.parameters()]) + if multi_period_discriminator is not None + else 0 + ) logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}") num_param_dstft = sum( [p.numel() for p in multi_scale_stft_discriminator.parameters()] @@ -998,12 +987,15 @@ def run(rank, world_size, args): lr=params.lr, betas=(0.5, 0.9), ) + discriminator_params = [ + multi_scale_stft_discriminator.parameters(), + ] + if multi_scale_discriminator is not None: + discriminator_params.append(multi_scale_discriminator.parameters()) + if multi_period_discriminator is not None: + discriminator_params.append(multi_period_discriminator.parameters()) optimizer_d = torch.optim.AdamW( - itertools.chain( - multi_scale_stft_discriminator.parameters(), - multi_scale_discriminator.parameters(), - multi_period_discriminator.parameters(), - ), + itertools.chain(*discriminator_params), lr=params.lr, betas=(0.5, 0.9), )