mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 10:16:14 +00:00
making MSD and MPD optional
commit e788bb4853
parent f9340cc5d7
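
This commit makes the MultiScaleDiscriminator (MSD) and MultiPeriodDiscriminator (MPD) optional inputs to the Encodec module: both may now be passed as None, every code path that touches them is guarded by an "is not None" check, get_model passes None for both by default, and the discriminator optimizer and parameter-count logging only include the discriminators that are actually present. As a rough sketch of the resulting calling convention (a stand-in class is used below, since the full Encodec signature is not part of this diff):

    from typing import Optional

    import torch.nn as nn


    class ToyEncodec(nn.Module):
        """Stand-in that mirrors the optional-discriminator interface; not the real Encodec."""

        def __init__(
            self,
            multi_scale_discriminator: Optional[nn.Module] = None,
            multi_period_discriminator: Optional[nn.Module] = None,
            multi_scale_stft_discriminator: Optional[nn.Module] = None,
        ):
            super().__init__()
            self.multi_scale_discriminator = multi_scale_discriminator
            self.multi_period_discriminator = multi_period_discriminator
            self.multi_scale_stft_discriminator = multi_scale_stft_discriminator


    # After this change, MSD and MPD can simply be left out (i.e. kept as None):
    model = ToyEncodec(multi_scale_stft_discriminator=nn.Identity())
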
@@ -1,6 +1,6 @@
 import math
 import random
-from typing import List
+from typing import List, Optional
 
 import numpy as np
 import torch
@@ -25,8 +25,8 @@ class Encodec(nn.Module):
         quantizer: nn.Module,
         decoder: nn.Module,
         multi_scale_discriminator: nn.Module,
-        multi_period_discriminator: nn.Module,
-        multi_scale_stft_discriminator: nn.Module,
+        multi_period_discriminator: Optional[nn.Module] = None,
+        multi_scale_stft_discriminator: Optional[nn.Module] = None,
         cache_generator_outputs: bool = False,
     ):
         super(Encodec, self).__init__()
@@ -113,28 +113,42 @@ class Encodec(nn.Module):
         with torch.no_grad():
             # do not store discriminator gradient in generator turn
             y, fmap = self.multi_scale_stft_discriminator(speech.contiguous())
-            y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
-                speech.contiguous(),
-                speech_hat.contiguous(),
-            )
-            y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
-                speech.contiguous(),
-                speech_hat.contiguous(),
-            )
+
+            gen_period_adv_loss = torch.tensor(0.0)
+            feature_period_loss = torch.tensor(0.0)
+            if self.multi_period_discriminator is not None:
+                y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
+                    speech.contiguous(),
+                    speech_hat.contiguous(),
+                )
+
+            gen_scale_adv_loss = torch.tensor(0.0)
+            feature_scale_loss = torch.tensor(0.0)
+            if self.multi_scale_discriminator is not None:
+                y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
+                    speech.contiguous(),
+                    speech_hat.contiguous(),
+                )
 
         # calculate losses
         with autocast(enabled=False):
             gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat)
-            gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
-            gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)
+
+            if self.multi_period_discriminator is not None:
+                gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
+            if self.multi_scale_discriminator is not None:
+                gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)
 
             feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat)
-            feature_period_loss = self.feature_match_loss(
-                feats=fmap_p, feats_hat=fmap_p_hat
-            )
-            feature_scale_loss = self.feature_match_loss(
-                feats=fmap_s, feats_hat=fmap_s_hat
-            )
+
+            if self.multi_period_discriminator is not None:
+                feature_period_loss = self.feature_match_loss(
+                    feats=fmap_p, feats_hat=fmap_p_hat
+                )
+            if self.multi_scale_discriminator is not None:
+                feature_scale_loss = self.feature_match_loss(
+                    feats=fmap_s, feats_hat=fmap_s_hat
+                )
 
         wav_reconstruction_loss = self.wav_reconstruction_loss(
             x=speech, x_hat=speech_hat
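
The guard pattern used in this hunk (and again in the discriminator hunk below) is to pre-initialize the period/scale losses to torch.tensor(0.0) and only overwrite them when the corresponding discriminator exists, so the downstream loss sums and stats keep working unchanged. A self-contained sketch of that pattern, with a trivial stand-in adversarial loss (everything except the is-not-None guard is illustrative):

    import torch


    def toy_adv_loss(outputs: torch.Tensor) -> torch.Tensor:
        # Stand-in for generator_adversarial_loss.
        return outputs.pow(2).mean()


    multi_period_discriminator = None  # MPD disabled, as in get_model() below
    speech_hat = torch.randn(1, 1, 160)

    # Placeholder keeps the later sum/logging valid even when MPD is off.
    gen_period_adv_loss = torch.tensor(0.0)
    if multi_period_discriminator is not None:
        y_p_hat = multi_period_discriminator(speech_hat)
        gen_period_adv_loss = toy_adv_loss(y_p_hat)

    total_gen_adv_loss = gen_period_adv_loss  # contributes 0.0 when MPD is disabled
    print(total_gen_adv_loss.item())  # -> 0.0
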
@@ -245,28 +259,44 @@ class Encodec(nn.Module):
         y_hat, fmap_hat = self.multi_scale_stft_discriminator(
             speech_hat.contiguous().detach()
         )
-        y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
-            speech.contiguous(),
-            speech_hat.contiguous().detach(),
-        )
-        y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
-            speech.contiguous(),
-            speech_hat.contiguous().detach(),
-        )
+
+        disc_period_real_adv_loss, disc_period_fake_adv_loss = torch.tensor(
+            0.0
+        ), torch.tensor(0.0)
+        if self.multi_period_discriminator is not None:
+            y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
+                speech.contiguous(),
+                speech_hat.contiguous().detach(),
+            )
+
+        disc_scale_real_adv_loss, disc_scale_fake_adv_loss = torch.tensor(
+            0.0
+        ), torch.tensor(0.0)
+        if self.multi_scale_discriminator is not None:
+            y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
+                speech.contiguous(),
+                speech_hat.contiguous().detach(),
+            )
         # calculate losses
         with autocast(enabled=False):
             (
                 disc_stft_real_adv_loss,
                 disc_stft_fake_adv_loss,
             ) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat)
-            (
-                disc_period_real_adv_loss,
-                disc_period_fake_adv_loss,
-            ) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat)
-            (
-                disc_scale_real_adv_loss,
-                disc_scale_fake_adv_loss,
-            ) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat)
+            if self.multi_period_discriminator is not None:
+                (
+                    disc_period_real_adv_loss,
+                    disc_period_fake_adv_loss,
+                ) = self.discriminator_adversarial_loss(
+                    outputs=y_p, outputs_hat=y_p_hat
+                )
+            if self.multi_scale_discriminator is not None:
+                (
+                    disc_scale_real_adv_loss,
+                    disc_scale_fake_adv_loss,
+                ) = self.discriminator_adversarial_loss(
+                    outputs=y_s, outputs_hat=y_s_hat
+                )
 
         stats = dict(
             discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(),
@@ -313,8 +313,8 @@ def get_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         quantizer=quantizer,
         decoder=decoder,
-        multi_scale_discriminator=MultiScaleDiscriminator(),
-        multi_period_discriminator=MultiPeriodDiscriminator(),
+        multi_scale_discriminator=None,
+        multi_period_discriminator=None,
         multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(
             n_filters=params.stft_discriminator_n_filters
         ),
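
Note that this makes the recipe train with the MS-STFT discriminator alone by default. The previous behaviour can presumably be restored by passing MultiScaleDiscriminator() and MultiPeriodDiscriminator() here again, exactly as in the removed lines, since the guarded code paths above handle both cases.
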
@@ -456,17 +456,13 @@ def train_one_epoch(
                 forward_generator=False,
             )
             disc_loss = (
-                (
-                    disc_stft_real_adv_loss
-                    + disc_stft_fake_adv_loss
-                    + disc_period_real_adv_loss
-                    + disc_period_fake_adv_loss
-                    + disc_scale_real_adv_loss
-                    + disc_scale_fake_adv_loss
-                )
-                * d_weight
-                / 3
-            )
+                disc_stft_real_adv_loss
+                + disc_stft_fake_adv_loss
+                + disc_period_real_adv_loss
+                + disc_period_fake_adv_loss
+                + disc_scale_real_adv_loss
+                + disc_scale_fake_adv_loss
+            ) * d_weight
             for k, v in stats_d.items():
                 loss_info[k] = v * batch_size
             # update discriminator
@@ -499,13 +495,11 @@ def train_one_epoch(
                 return_sample=params.batch_idx_train % params.log_interval == 0,
             )
             gen_adv_loss = (
-                (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
-                * g_weight
-                / 3
-            )
+                gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss
+            ) * g_weight
             feature_loss = (
                 feature_stft_loss + feature_period_loss + feature_scale_loss
-            ) / 3
+            )
             reconstruction_loss = (
                 params.lambda_wav * wav_reconstruction_loss
                 + mel_reconstruction_loss
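
Both here and in the disc_loss hunk above (and in the matching validation and OOM-scan code below), the division by 3 that averaged over the three discriminator branches is dropped: the adversarial and feature losses become plain weighted sums. With MSD and MPD disabled their terms are exact zeros, so the remaining STFT branch is no longer scaled down by the old averaging. A small worked comparison with illustrative values:

    import torch

    # Illustrative values only.
    gen_stft_adv_loss = torch.tensor(0.9)
    gen_period_adv_loss = torch.tensor(0.0)  # MPD disabled -> placeholder zero
    gen_scale_adv_loss = torch.tensor(0.0)   # MSD disabled -> placeholder zero
    g_weight = 1.0

    # Before this commit: mean over the three discriminator branches.
    old_gen_adv_loss = (
        (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) * g_weight / 3
    )

    # After this commit: plain weighted sum.
    new_gen_adv_loss = (
        gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss
    ) * g_weight

    print(old_gen_adv_loss.item(), new_gen_adv_loss.item())  # ~0.3 vs 0.9
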
@@ -714,17 +708,13 @@ def compute_validation_loss(
                 forward_generator=False,
             )
             disc_loss = (
-                (
-                    disc_stft_real_adv_loss
-                    + disc_stft_fake_adv_loss
-                    + disc_period_real_adv_loss
-                    + disc_period_fake_adv_loss
-                    + disc_scale_real_adv_loss
-                    + disc_scale_fake_adv_loss
-                )
-                * d_weight
-                / 3
-            )
+                disc_stft_real_adv_loss
+                + disc_stft_fake_adv_loss
+                + disc_period_real_adv_loss
+                + disc_period_fake_adv_loss
+                + disc_scale_real_adv_loss
+                + disc_scale_fake_adv_loss
+            ) * d_weight
             assert disc_loss.requires_grad is False
             for k, v in stats_d.items():
                 loss_info[k] = v * batch_size
@@ -753,13 +743,9 @@ def compute_validation_loss(
                 return_sample=False,
             )
             gen_adv_loss = (
-                (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
-                * g_weight
-                / 3
-            )
-            feature_loss = (
-                feature_stft_loss + feature_period_loss + feature_scale_loss
-            ) / 3
+                gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss
+            ) * g_weight
+            feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss
             reconstruction_loss = (
                 params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss
             )
@@ -836,20 +822,16 @@ def scan_pessimistic_batches_for_oom(
                 forward_generator=False,
             )
             loss_d = (
-                (
-                    disc_stft_real_adv_loss
-                    + disc_stft_fake_adv_loss
-                    + disc_period_real_adv_loss
-                    + disc_period_fake_adv_loss
-                    + disc_scale_real_adv_loss
-                    + disc_scale_fake_adv_loss
-                )
-                * adopt_weight(
-                    params.lambda_adv,
-                    params.batch_idx_train,
-                    threshold=params.discriminator_iter_start,
-                )
-                / 3
+                disc_stft_real_adv_loss
+                + disc_stft_fake_adv_loss
+                + disc_period_real_adv_loss
+                + disc_period_fake_adv_loss
+                + disc_scale_real_adv_loss
+                + disc_scale_fake_adv_loss
+            ) * adopt_weight(
+                params.lambda_adv,
+                params.batch_idx_train,
+                threshold=params.discriminator_iter_start,
             )
             optimizer_d.zero_grad()
             loss_d.backward()
@@ -879,7 +861,6 @@ def scan_pessimistic_batches_for_oom(
                     params.batch_idx_train,
                     threshold=params.discriminator_iter_start,
                 )
-                / 3
                 + params.lambda_rec
                 * (
                     params.lambda_wav * wav_reconstruction_loss
@@ -962,9 +943,17 @@ def run(rank, world_size, args):
     logging.info(f"Number of parameters in decoder: {num_param_d}")
     num_param_q = sum([p.numel() for p in quantizer.parameters()])
     logging.info(f"Number of parameters in quantizer: {num_param_q}")
-    num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()])
+    num_param_ds = (
+        sum([p.numel() for p in multi_scale_discriminator.parameters()])
+        if multi_scale_discriminator is not None
+        else 0
+    )
     logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}")
-    num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()])
+    num_param_dp = (
+        sum([p.numel() for p in multi_period_discriminator.parameters()])
+        if multi_period_discriminator is not None
+        else 0
+    )
     logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}")
     num_param_dstft = sum(
         [p.numel() for p in multi_scale_stft_discriminator.parameters()]
@@ -998,12 +987,15 @@ def run(rank, world_size, args):
         lr=params.lr,
         betas=(0.5, 0.9),
     )
+    discriminator_params = [
+        multi_scale_stft_discriminator.parameters(),
+    ]
+    if multi_scale_discriminator is not None:
+        discriminator_params.append(multi_scale_discriminator.parameters())
+    if multi_period_discriminator is not None:
+        discriminator_params.append(multi_period_discriminator.parameters())
     optimizer_d = torch.optim.AdamW(
-        itertools.chain(
-            multi_scale_stft_discriminator.parameters(),
-            multi_scale_discriminator.parameters(),
-            multi_period_discriminator.parameters(),
-        ),
+        itertools.chain(*discriminator_params),
         lr=params.lr,
         betas=(0.5, 0.9),
     )
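
The discriminator optimizer is now built from a list of parameter iterators that only contains the discriminators that exist, flattened with itertools.chain. A self-contained sketch of that construction with small stand-in modules (the lr value below is illustrative; the diff uses params.lr):

    import itertools

    import torch
    import torch.nn as nn

    multi_scale_stft_discriminator = nn.Linear(4, 4)  # stand-in module
    multi_scale_discriminator = None  # disabled
    multi_period_discriminator = nn.Linear(4, 4)  # stand-in module

    discriminator_params = [
        multi_scale_stft_discriminator.parameters(),
    ]
    if multi_scale_discriminator is not None:
        discriminator_params.append(multi_scale_discriminator.parameters())
    if multi_period_discriminator is not None:
        discriminator_params.append(multi_period_discriminator.parameters())

    # One optimizer over whatever discriminators are actually present.
    optimizer_d = torch.optim.AdamW(
        itertools.chain(*discriminator_params),
        lr=3e-4,
        betas=(0.5, 0.9),
    )
    print(sum(len(g["params"]) for g in optimizer_d.param_groups))  # -> 4 tensors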