Make MSD (multi-scale discriminator) and MPD (multi-period discriminator) optional

This commit is contained in:
JinZr 2024-10-06 13:38:05 +08:00
parent f9340cc5d7
commit e788bb4853
2 changed files with 115 additions and 93 deletions

View File

@ -1,6 +1,6 @@
import math import math
import random import random
from typing import List from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
@ -25,8 +25,8 @@ class Encodec(nn.Module):
quantizer: nn.Module, quantizer: nn.Module,
decoder: nn.Module, decoder: nn.Module,
multi_scale_discriminator: nn.Module, multi_scale_discriminator: nn.Module,
multi_period_discriminator: nn.Module, multi_period_discriminator: Optional[nn.Module] = None,
multi_scale_stft_discriminator: nn.Module, multi_scale_stft_discriminator: Optional[nn.Module] = None,
cache_generator_outputs: bool = False, cache_generator_outputs: bool = False,
): ):
super(Encodec, self).__init__() super(Encodec, self).__init__()
@ -113,28 +113,42 @@ class Encodec(nn.Module):
with torch.no_grad(): with torch.no_grad():
# do not store discriminator gradient in generator turn # do not store discriminator gradient in generator turn
y, fmap = self.multi_scale_stft_discriminator(speech.contiguous()) y, fmap = self.multi_scale_stft_discriminator(speech.contiguous())
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
speech.contiguous(), gen_period_adv_loss = torch.tensor(0.0)
speech_hat.contiguous(), feature_period_loss = torch.tensor(0.0)
) if self.multi_period_discriminator is not None:
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
speech.contiguous(), speech.contiguous(),
speech_hat.contiguous(), speech_hat.contiguous(),
) )
gen_scale_adv_loss = torch.tensor(0.0)
feature_scale_loss = torch.tensor(0.0)
if self.multi_scale_discriminator is not None:
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
speech.contiguous(),
speech_hat.contiguous(),
)
# calculate losses # calculate losses
with autocast(enabled=False): with autocast(enabled=False):
gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat) gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat)
gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat) if self.multi_period_discriminator is not None:
gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
if self.multi_scale_discriminator is not None:
gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)
feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat) feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat)
feature_period_loss = self.feature_match_loss(
feats=fmap_p, feats_hat=fmap_p_hat if self.multi_period_discriminator is not None:
) feature_period_loss = self.feature_match_loss(
feature_scale_loss = self.feature_match_loss( feats=fmap_p, feats_hat=fmap_p_hat
feats=fmap_s, feats_hat=fmap_s_hat )
) if self.multi_scale_discriminator is not None:
feature_scale_loss = self.feature_match_loss(
feats=fmap_s, feats_hat=fmap_s_hat
)
wav_reconstruction_loss = self.wav_reconstruction_loss( wav_reconstruction_loss = self.wav_reconstruction_loss(
x=speech, x_hat=speech_hat x=speech, x_hat=speech_hat
@ -245,28 +259,44 @@ class Encodec(nn.Module):
y_hat, fmap_hat = self.multi_scale_stft_discriminator( y_hat, fmap_hat = self.multi_scale_stft_discriminator(
speech_hat.contiguous().detach() speech_hat.contiguous().detach()
) )
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
speech.contiguous(), disc_period_real_adv_loss, disc_period_fake_adv_loss = torch.tensor(
speech_hat.contiguous().detach(), 0.0
) ), torch.tensor(0.0)
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( if self.multi_period_discriminator is not None:
speech.contiguous(), y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
speech_hat.contiguous().detach(), speech.contiguous(),
) speech_hat.contiguous().detach(),
)
disc_scale_real_adv_loss, disc_scale_fake_adv_loss = torch.tensor(
0.0
), torch.tensor(0.0)
if self.multi_scale_discriminator is not None:
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
speech.contiguous(),
speech_hat.contiguous().detach(),
)
# calculate losses # calculate losses
with autocast(enabled=False): with autocast(enabled=False):
( (
disc_stft_real_adv_loss, disc_stft_real_adv_loss,
disc_stft_fake_adv_loss, disc_stft_fake_adv_loss,
) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat) ) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat)
( if self.multi_period_discriminator is not None:
disc_period_real_adv_loss, (
disc_period_fake_adv_loss, disc_period_real_adv_loss,
) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat) disc_period_fake_adv_loss,
( ) = self.discriminator_adversarial_loss(
disc_scale_real_adv_loss, outputs=y_p, outputs_hat=y_p_hat
disc_scale_fake_adv_loss, )
) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat) if self.multi_scale_discriminator is not None:
(
disc_scale_real_adv_loss,
disc_scale_fake_adv_loss,
) = self.discriminator_adversarial_loss(
outputs=y_s, outputs_hat=y_s_hat
)
stats = dict( stats = dict(
discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(), discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(),

View File

@ -313,8 +313,8 @@ def get_model(params: AttributeDict) -> nn.Module:
encoder=encoder, encoder=encoder,
quantizer=quantizer, quantizer=quantizer,
decoder=decoder, decoder=decoder,
multi_scale_discriminator=MultiScaleDiscriminator(), multi_scale_discriminator=None,
multi_period_discriminator=MultiPeriodDiscriminator(), multi_period_discriminator=None,
multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator( multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(
n_filters=params.stft_discriminator_n_filters n_filters=params.stft_discriminator_n_filters
), ),
@ -456,17 +456,13 @@ def train_one_epoch(
forward_generator=False, forward_generator=False,
) )
disc_loss = ( disc_loss = (
( disc_stft_real_adv_loss
disc_stft_real_adv_loss + disc_stft_fake_adv_loss
+ disc_stft_fake_adv_loss + disc_period_real_adv_loss
+ disc_period_real_adv_loss + disc_period_fake_adv_loss
+ disc_period_fake_adv_loss + disc_scale_real_adv_loss
+ disc_scale_real_adv_loss + disc_scale_fake_adv_loss
+ disc_scale_fake_adv_loss ) * d_weight
)
* d_weight
/ 3
)
for k, v in stats_d.items(): for k, v in stats_d.items():
loss_info[k] = v * batch_size loss_info[k] = v * batch_size
# update discriminator # update discriminator
@ -499,13 +495,11 @@ def train_one_epoch(
return_sample=params.batch_idx_train % params.log_interval == 0, return_sample=params.batch_idx_train % params.log_interval == 0,
) )
gen_adv_loss = ( gen_adv_loss = (
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss
* g_weight ) * g_weight
/ 3
)
feature_loss = ( feature_loss = (
feature_stft_loss + feature_period_loss + feature_scale_loss feature_stft_loss + feature_period_loss + feature_scale_loss
) / 3 )
reconstruction_loss = ( reconstruction_loss = (
params.lambda_wav * wav_reconstruction_loss params.lambda_wav * wav_reconstruction_loss
+ mel_reconstruction_loss + mel_reconstruction_loss
@ -714,17 +708,13 @@ def compute_validation_loss(
forward_generator=False, forward_generator=False,
) )
disc_loss = ( disc_loss = (
( disc_stft_real_adv_loss
disc_stft_real_adv_loss + disc_stft_fake_adv_loss
+ disc_stft_fake_adv_loss + disc_period_real_adv_loss
+ disc_period_real_adv_loss + disc_period_fake_adv_loss
+ disc_period_fake_adv_loss + disc_scale_real_adv_loss
+ disc_scale_real_adv_loss + disc_scale_fake_adv_loss
+ disc_scale_fake_adv_loss ) * d_weight
)
* d_weight
/ 3
)
assert disc_loss.requires_grad is False assert disc_loss.requires_grad is False
for k, v in stats_d.items(): for k, v in stats_d.items():
loss_info[k] = v * batch_size loss_info[k] = v * batch_size
@ -753,13 +743,9 @@ def compute_validation_loss(
return_sample=False, return_sample=False,
) )
gen_adv_loss = ( gen_adv_loss = (
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss
* g_weight ) * g_weight
/ 3 feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss
)
feature_loss = (
feature_stft_loss + feature_period_loss + feature_scale_loss
) / 3
reconstruction_loss = ( reconstruction_loss = (
params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss
) )
@ -836,20 +822,16 @@ def scan_pessimistic_batches_for_oom(
forward_generator=False, forward_generator=False,
) )
loss_d = ( loss_d = (
( disc_stft_real_adv_loss
disc_stft_real_adv_loss + disc_stft_fake_adv_loss
+ disc_stft_fake_adv_loss + disc_period_real_adv_loss
+ disc_period_real_adv_loss + disc_period_fake_adv_loss
+ disc_period_fake_adv_loss + disc_scale_real_adv_loss
+ disc_scale_real_adv_loss + disc_scale_fake_adv_loss
+ disc_scale_fake_adv_loss ) * adopt_weight(
) params.lambda_adv,
* adopt_weight( params.batch_idx_train,
params.lambda_adv, threshold=params.discriminator_iter_start,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
)
/ 3
) )
optimizer_d.zero_grad() optimizer_d.zero_grad()
loss_d.backward() loss_d.backward()
@ -879,7 +861,6 @@ def scan_pessimistic_batches_for_oom(
params.batch_idx_train, params.batch_idx_train,
threshold=params.discriminator_iter_start, threshold=params.discriminator_iter_start,
) )
/ 3
+ params.lambda_rec + params.lambda_rec
* ( * (
params.lambda_wav * wav_reconstruction_loss params.lambda_wav * wav_reconstruction_loss
@ -962,9 +943,17 @@ def run(rank, world_size, args):
logging.info(f"Number of parameters in decoder: {num_param_d}") logging.info(f"Number of parameters in decoder: {num_param_d}")
num_param_q = sum([p.numel() for p in quantizer.parameters()]) num_param_q = sum([p.numel() for p in quantizer.parameters()])
logging.info(f"Number of parameters in quantizer: {num_param_q}") logging.info(f"Number of parameters in quantizer: {num_param_q}")
num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()]) num_param_ds = (
sum([p.numel() for p in multi_scale_discriminator.parameters()])
if multi_scale_discriminator is not None
else 0
)
logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}") logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}")
num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()]) num_param_dp = (
sum([p.numel() for p in multi_period_discriminator.parameters()])
if multi_period_discriminator is not None
else 0
)
logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}") logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}")
num_param_dstft = sum( num_param_dstft = sum(
[p.numel() for p in multi_scale_stft_discriminator.parameters()] [p.numel() for p in multi_scale_stft_discriminator.parameters()]
@ -998,12 +987,15 @@ def run(rank, world_size, args):
lr=params.lr, lr=params.lr,
betas=(0.5, 0.9), betas=(0.5, 0.9),
) )
discriminator_params = [
multi_scale_stft_discriminator.parameters(),
]
if multi_scale_discriminator is not None:
discriminator_params.append(multi_scale_discriminator.parameters())
if multi_period_discriminator is not None:
discriminator_params.append(multi_period_discriminator.parameters())
optimizer_d = torch.optim.AdamW( optimizer_d = torch.optim.AdamW(
itertools.chain( itertools.chain(*discriminator_params),
multi_scale_stft_discriminator.parameters(),
multi_scale_discriminator.parameters(),
multi_period_discriminator.parameters(),
),
lr=params.lr, lr=params.lr,
betas=(0.5, 0.9), betas=(0.5, 0.9),
) )