mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 02:06:13 +00:00
Make the multi-scale discriminator (MSD) and multi-period discriminator (MPD) optional
This commit is contained in:
parent
f9340cc5d7
commit
e788bb4853
@ -1,6 +1,6 @@
|
||||
import math
|
||||
import random
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -25,8 +25,8 @@ class Encodec(nn.Module):
|
||||
quantizer: nn.Module,
|
||||
decoder: nn.Module,
|
||||
multi_scale_discriminator: nn.Module,
|
||||
multi_period_discriminator: nn.Module,
|
||||
multi_scale_stft_discriminator: nn.Module,
|
||||
multi_period_discriminator: Optional[nn.Module] = None,
|
||||
multi_scale_stft_discriminator: Optional[nn.Module] = None,
|
||||
cache_generator_outputs: bool = False,
|
||||
):
|
||||
super(Encodec, self).__init__()
|
||||
@ -113,28 +113,42 @@ class Encodec(nn.Module):
|
||||
with torch.no_grad():
|
||||
# do not store discriminator gradient in generator turn
|
||||
y, fmap = self.multi_scale_stft_discriminator(speech.contiguous())
|
||||
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous(),
|
||||
)
|
||||
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous(),
|
||||
)
|
||||
|
||||
gen_period_adv_loss = torch.tensor(0.0)
|
||||
feature_period_loss = torch.tensor(0.0)
|
||||
if self.multi_period_discriminator is not None:
|
||||
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous(),
|
||||
)
|
||||
|
||||
gen_scale_adv_loss = torch.tensor(0.0)
|
||||
feature_scale_loss = torch.tensor(0.0)
|
||||
if self.multi_scale_discriminator is not None:
|
||||
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous(),
|
||||
)
|
||||
|
||||
# calculate losses
|
||||
with autocast(enabled=False):
|
||||
gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat)
|
||||
gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
|
||||
gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)
|
||||
|
||||
if self.multi_period_discriminator is not None:
|
||||
gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
|
||||
if self.multi_scale_discriminator is not None:
|
||||
gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)
|
||||
|
||||
feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat)
|
||||
feature_period_loss = self.feature_match_loss(
|
||||
feats=fmap_p, feats_hat=fmap_p_hat
|
||||
)
|
||||
feature_scale_loss = self.feature_match_loss(
|
||||
feats=fmap_s, feats_hat=fmap_s_hat
|
||||
)
|
||||
|
||||
if self.multi_period_discriminator is not None:
|
||||
feature_period_loss = self.feature_match_loss(
|
||||
feats=fmap_p, feats_hat=fmap_p_hat
|
||||
)
|
||||
if self.multi_scale_discriminator is not None:
|
||||
feature_scale_loss = self.feature_match_loss(
|
||||
feats=fmap_s, feats_hat=fmap_s_hat
|
||||
)
|
||||
|
||||
wav_reconstruction_loss = self.wav_reconstruction_loss(
|
||||
x=speech, x_hat=speech_hat
|
||||
@ -245,28 +259,44 @@ class Encodec(nn.Module):
|
||||
y_hat, fmap_hat = self.multi_scale_stft_discriminator(
|
||||
speech_hat.contiguous().detach()
|
||||
)
|
||||
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous().detach(),
|
||||
)
|
||||
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous().detach(),
|
||||
)
|
||||
|
||||
disc_period_real_adv_loss, disc_period_fake_adv_loss = torch.tensor(
|
||||
0.0
|
||||
), torch.tensor(0.0)
|
||||
if self.multi_period_discriminator is not None:
|
||||
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous().detach(),
|
||||
)
|
||||
|
||||
disc_scale_real_adv_loss, disc_scale_fake_adv_loss = torch.tensor(
|
||||
0.0
|
||||
), torch.tensor(0.0)
|
||||
if self.multi_scale_discriminator is not None:
|
||||
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous().detach(),
|
||||
)
|
||||
# calculate losses
|
||||
with autocast(enabled=False):
|
||||
(
|
||||
disc_stft_real_adv_loss,
|
||||
disc_stft_fake_adv_loss,
|
||||
) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat)
|
||||
(
|
||||
disc_period_real_adv_loss,
|
||||
disc_period_fake_adv_loss,
|
||||
) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat)
|
||||
(
|
||||
disc_scale_real_adv_loss,
|
||||
disc_scale_fake_adv_loss,
|
||||
) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat)
|
||||
if self.multi_period_discriminator is not None:
|
||||
(
|
||||
disc_period_real_adv_loss,
|
||||
disc_period_fake_adv_loss,
|
||||
) = self.discriminator_adversarial_loss(
|
||||
outputs=y_p, outputs_hat=y_p_hat
|
||||
)
|
||||
if self.multi_scale_discriminator is not None:
|
||||
(
|
||||
disc_scale_real_adv_loss,
|
||||
disc_scale_fake_adv_loss,
|
||||
) = self.discriminator_adversarial_loss(
|
||||
outputs=y_s, outputs_hat=y_s_hat
|
||||
)
|
||||
|
||||
stats = dict(
|
||||
discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(),
|
||||
|
@ -313,8 +313,8 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
encoder=encoder,
|
||||
quantizer=quantizer,
|
||||
decoder=decoder,
|
||||
multi_scale_discriminator=MultiScaleDiscriminator(),
|
||||
multi_period_discriminator=MultiPeriodDiscriminator(),
|
||||
multi_scale_discriminator=None,
|
||||
multi_period_discriminator=None,
|
||||
multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(
|
||||
n_filters=params.stft_discriminator_n_filters
|
||||
),
|
||||
@ -456,17 +456,13 @@ def train_one_epoch(
|
||||
forward_generator=False,
|
||||
)
|
||||
disc_loss = (
|
||||
(
|
||||
disc_stft_real_adv_loss
|
||||
+ disc_stft_fake_adv_loss
|
||||
+ disc_period_real_adv_loss
|
||||
+ disc_period_fake_adv_loss
|
||||
+ disc_scale_real_adv_loss
|
||||
+ disc_scale_fake_adv_loss
|
||||
)
|
||||
* d_weight
|
||||
/ 3
|
||||
)
|
||||
disc_stft_real_adv_loss
|
||||
+ disc_stft_fake_adv_loss
|
||||
+ disc_period_real_adv_loss
|
||||
+ disc_period_fake_adv_loss
|
||||
+ disc_scale_real_adv_loss
|
||||
+ disc_scale_fake_adv_loss
|
||||
) * d_weight
|
||||
for k, v in stats_d.items():
|
||||
loss_info[k] = v * batch_size
|
||||
# update discriminator
|
||||
@ -499,13 +495,11 @@ def train_one_epoch(
|
||||
return_sample=params.batch_idx_train % params.log_interval == 0,
|
||||
)
|
||||
gen_adv_loss = (
|
||||
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
|
||||
* g_weight
|
||||
/ 3
|
||||
)
|
||||
gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss
|
||||
) * g_weight
|
||||
feature_loss = (
|
||||
feature_stft_loss + feature_period_loss + feature_scale_loss
|
||||
) / 3
|
||||
)
|
||||
reconstruction_loss = (
|
||||
params.lambda_wav * wav_reconstruction_loss
|
||||
+ mel_reconstruction_loss
|
||||
@ -714,17 +708,13 @@ def compute_validation_loss(
|
||||
forward_generator=False,
|
||||
)
|
||||
disc_loss = (
|
||||
(
|
||||
disc_stft_real_adv_loss
|
||||
+ disc_stft_fake_adv_loss
|
||||
+ disc_period_real_adv_loss
|
||||
+ disc_period_fake_adv_loss
|
||||
+ disc_scale_real_adv_loss
|
||||
+ disc_scale_fake_adv_loss
|
||||
)
|
||||
* d_weight
|
||||
/ 3
|
||||
)
|
||||
disc_stft_real_adv_loss
|
||||
+ disc_stft_fake_adv_loss
|
||||
+ disc_period_real_adv_loss
|
||||
+ disc_period_fake_adv_loss
|
||||
+ disc_scale_real_adv_loss
|
||||
+ disc_scale_fake_adv_loss
|
||||
) * d_weight
|
||||
assert disc_loss.requires_grad is False
|
||||
for k, v in stats_d.items():
|
||||
loss_info[k] = v * batch_size
|
||||
@ -753,13 +743,9 @@ def compute_validation_loss(
|
||||
return_sample=False,
|
||||
)
|
||||
gen_adv_loss = (
|
||||
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
|
||||
* g_weight
|
||||
/ 3
|
||||
)
|
||||
feature_loss = (
|
||||
feature_stft_loss + feature_period_loss + feature_scale_loss
|
||||
) / 3
|
||||
gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss
|
||||
) * g_weight
|
||||
feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss
|
||||
reconstruction_loss = (
|
||||
params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss
|
||||
)
|
||||
@ -836,20 +822,16 @@ def scan_pessimistic_batches_for_oom(
|
||||
forward_generator=False,
|
||||
)
|
||||
loss_d = (
|
||||
(
|
||||
disc_stft_real_adv_loss
|
||||
+ disc_stft_fake_adv_loss
|
||||
+ disc_period_real_adv_loss
|
||||
+ disc_period_fake_adv_loss
|
||||
+ disc_scale_real_adv_loss
|
||||
+ disc_scale_fake_adv_loss
|
||||
)
|
||||
* adopt_weight(
|
||||
params.lambda_adv,
|
||||
params.batch_idx_train,
|
||||
threshold=params.discriminator_iter_start,
|
||||
)
|
||||
/ 3
|
||||
disc_stft_real_adv_loss
|
||||
+ disc_stft_fake_adv_loss
|
||||
+ disc_period_real_adv_loss
|
||||
+ disc_period_fake_adv_loss
|
||||
+ disc_scale_real_adv_loss
|
||||
+ disc_scale_fake_adv_loss
|
||||
) * adopt_weight(
|
||||
params.lambda_adv,
|
||||
params.batch_idx_train,
|
||||
threshold=params.discriminator_iter_start,
|
||||
)
|
||||
optimizer_d.zero_grad()
|
||||
loss_d.backward()
|
||||
@ -879,7 +861,6 @@ def scan_pessimistic_batches_for_oom(
|
||||
params.batch_idx_train,
|
||||
threshold=params.discriminator_iter_start,
|
||||
)
|
||||
/ 3
|
||||
+ params.lambda_rec
|
||||
* (
|
||||
params.lambda_wav * wav_reconstruction_loss
|
||||
@ -962,9 +943,17 @@ def run(rank, world_size, args):
|
||||
logging.info(f"Number of parameters in decoder: {num_param_d}")
|
||||
num_param_q = sum([p.numel() for p in quantizer.parameters()])
|
||||
logging.info(f"Number of parameters in quantizer: {num_param_q}")
|
||||
num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()])
|
||||
num_param_ds = (
|
||||
sum([p.numel() for p in multi_scale_discriminator.parameters()])
|
||||
if multi_scale_discriminator is not None
|
||||
else 0
|
||||
)
|
||||
logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}")
|
||||
num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()])
|
||||
num_param_dp = (
|
||||
sum([p.numel() for p in multi_period_discriminator.parameters()])
|
||||
if multi_period_discriminator is not None
|
||||
else 0
|
||||
)
|
||||
logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}")
|
||||
num_param_dstft = sum(
|
||||
[p.numel() for p in multi_scale_stft_discriminator.parameters()]
|
||||
@ -998,12 +987,15 @@ def run(rank, world_size, args):
|
||||
lr=params.lr,
|
||||
betas=(0.5, 0.9),
|
||||
)
|
||||
discriminator_params = [
|
||||
multi_scale_stft_discriminator.parameters(),
|
||||
]
|
||||
if multi_scale_discriminator is not None:
|
||||
discriminator_params.append(multi_scale_discriminator.parameters())
|
||||
if multi_period_discriminator is not None:
|
||||
discriminator_params.append(multi_period_discriminator.parameters())
|
||||
optimizer_d = torch.optim.AdamW(
|
||||
itertools.chain(
|
||||
multi_scale_stft_discriminator.parameters(),
|
||||
multi_scale_discriminator.parameters(),
|
||||
multi_period_discriminator.parameters(),
|
||||
),
|
||||
itertools.chain(*discriminator_params),
|
||||
lr=params.lr,
|
||||
betas=(0.5, 0.9),
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user