refactored loss functions

JinZr 2024-10-05 23:11:43 +08:00
parent 1e65a976d0
commit f9340cc5d7
5 changed files with 620 additions and 186 deletions

View File

@@ -139,7 +139,7 @@ class LibriTTSCodecDataModule:
group.add_argument(
"--num-workers",
type=int,
default=2,
default=8,
help="The number of training dataloader workers that "
"collect the batches.",
)

View File

@@ -4,7 +4,13 @@ from typing import List
import numpy as np
import torch
from loss import loss_dis, loss_g
from loss import (
DiscriminatorAdversarialLoss,
FeatureMatchLoss,
GeneratorAdversarialLoss,
MelSpectrogramReconstructionLoss,
WavReconstructionLoss,
)
from torch import nn
from torch.cuda.amp import autocast
@@ -47,11 +53,23 @@ class Encodec(nn.Module):
self.cache_generator_outputs = cache_generator_outputs
self._cache = None
# construct loss functions
self.generator_adversarial_loss = GeneratorAdversarialLoss(
average_by_discriminators=True, loss_type="hinge"
)
self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
average_by_discriminators=True, loss_type="hinge"
)
self.feature_match_loss = FeatureMatchLoss(average_by_layers=False)
self.wav_reconstruction_loss = WavReconstructionLoss()
self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
sampling_rate=self.sampling_rate
)
def _forward_generator(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
global_step: int,
return_sample: bool = False,
):
"""Perform generator forward.
@@ -59,7 +77,6 @@ class Encodec(nn.Module):
Args:
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
global_step (int): Global step.
return_sample (bool): Return the generator output.
Returns:
@@ -107,33 +124,56 @@ class Encodec(nn.Module):
# calculate losses
with autocast(enabled=False):
loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
commit_loss,
speech,
speech_hat,
fmap,
fmap_hat,
y,
y_hat,
global_step,
y_p,
y_p_hat,
y_s,
y_s_hat,
fmap_p,
fmap_p_hat,
fmap_s,
fmap_s_hat,
args=self.params,
gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat)
gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)
feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat)
feature_period_loss = self.feature_match_loss(
feats=fmap_p, feats_hat=fmap_p_hat
)
feature_scale_loss = self.feature_match_loss(
feats=fmap_s, feats_hat=fmap_s_hat
)
wav_reconstruction_loss = self.wav_reconstruction_loss(
x=speech, x_hat=speech_hat
)
mel_reconstruction_loss = self.mel_reconstruction_loss(
x=speech, x_hat=speech_hat
)
# loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
# commit_loss,
# speech,
# speech_hat,
# fmap,
# fmap_hat,
# y,
# y_hat,
# y_p,
# y_p_hat,
# y_s,
# y_s_hat,
# fmap_p,
# fmap_p_hat,
# fmap_s,
# fmap_s_hat,
# args=self.params,
# )
stats = dict(
generator_loss=loss.item(),
generator_reconstruction_loss=rec_loss.item(),
generator_feature_loss=feat_loss.item(),
generator_adv_loss=adv_loss.item(),
# generator_loss=loss.item(),
generator_wav_reconstruction_loss=wav_reconstruction_loss.item(),
generator_mel_reconstruction_loss=mel_reconstruction_loss.item(),
generator_feature_stft_loss=feature_stft_loss.item(),
generator_feature_period_loss=feature_period_loss.item(),
generator_feature_scale_loss=feature_scale_loss.item(),
generator_stft_adv_loss=gen_stft_adv_loss.item(),
generator_period_adv_loss=gen_period_adv_loss.item(),
generator_scale_adv_loss=gen_scale_adv_loss.item(),
generator_commit_loss=commit_loss.item(),
d_weight=d_weight.item(),
# d_weight=d_weight.item(),
)
if return_sample:
@@ -147,19 +187,28 @@ class Encodec(nn.Module):
# reset cache
if reuse_cache or not self.training:
self._cache = None
return loss, stats
return (
commit_loss,
gen_stft_adv_loss,
gen_period_adv_loss,
gen_scale_adv_loss,
feature_stft_loss,
feature_period_loss,
feature_scale_loss,
wav_reconstruction_loss,
mel_reconstruction_loss,
stats,
)
def _forward_discriminator(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
global_step: int,
):
"""
Args:
speech (Tensor): Speech waveform tensor (B, T_wav).
speech_lengths (Tensor): Speech length tensor (B,).
global_step (int): Global step.
Returns:
* loss (Tensor): Loss scalar tensor.
@@ -206,37 +255,46 @@ class Encodec(nn.Module):
)
# calculate losses
with autocast(enabled=False):
loss = loss_dis(
y,
y_hat,
fmap,
fmap_hat,
y_p,
y_p_hat,
fmap_p,
fmap_p_hat,
y_s,
y_s_hat,
fmap_s,
fmap_s_hat,
global_step,
args=self.params,
)
(
disc_stft_real_adv_loss,
disc_stft_fake_adv_loss,
) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat)
(
disc_period_real_adv_loss,
disc_period_fake_adv_loss,
) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat)
(
disc_scale_real_adv_loss,
disc_scale_fake_adv_loss,
) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat)
stats = dict(
discriminator_loss=loss.item(),
discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(),
discriminator_period_real_adv_loss=disc_period_real_adv_loss.item(),
discriminator_scale_real_adv_loss=disc_scale_real_adv_loss.item(),
discriminator_stft_fake_adv_loss=disc_stft_fake_adv_loss.item(),
discriminator_period_fake_adv_loss=disc_period_fake_adv_loss.item(),
discriminator_scale_fake_adv_loss=disc_scale_fake_adv_loss.item(),
)
# reset cache
if reuse_cache or not self.training:
self._cache = None
return loss, stats
return (
disc_stft_real_adv_loss,
disc_stft_fake_adv_loss,
disc_period_real_adv_loss,
disc_period_fake_adv_loss,
disc_scale_real_adv_loss,
disc_scale_fake_adv_loss,
stats,
)
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
global_step: int,
return_sample: bool,
forward_generator: bool,
):
@@ -244,14 +302,12 @@ class Encodec(nn.Module):
return self._forward_generator(
speech=speech,
speech_lengths=speech_lengths,
global_step=global_step,
return_sample=return_sample,
)
else:
return self._forward_discriminator(
speech=speech,
speech_lengths=speech_lengths,
global_step=global_step,
)
def encode(self, x, target_bw=None, st=None):

View File

@@ -71,7 +71,7 @@ def get_parser():
parser.add_argument(
"--target-bw",
type=float,
default=7.5,
default=24,
help="The target bandwidth for the generator",
)

View File

@@ -1,8 +1,310 @@
from typing import List, Tuple, Union
import torch
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank
from torchaudio.transforms import MelSpectrogram
class GeneratorAdversarialLoss(torch.nn.Module):
"""Generator adversarial loss module."""
def __init__(
self,
average_by_discriminators: bool = True,
loss_type: str = "hinge",
):
"""Initialize GeneratorAversarialLoss module.
Args:
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
loss_type (str): Loss type, "mse" or "hinge".
"""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.criterion = self._mse_loss
else:
self.criterion = self._hinge_loss
def forward(
self,
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
"""Calcualate generator adversarial loss.
Args:
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs.
Returns:
Tensor: Generator adversarial loss value.
"""
if isinstance(outputs, (tuple, list)):
adv_loss = 0.0
for i, outputs_ in enumerate(outputs):
if isinstance(outputs_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_ = outputs_[-1]
adv_loss += self.criterion(outputs_)
if self.average_by_discriminators:
adv_loss /= i + 1
else:
adv_loss = self.criterion(outputs)
return adv_loss
def _mse_loss(self, x):
return F.mse_loss(x, x.new_ones(x.size()))
def _hinge_loss(self, x):
return F.relu(1 - x).mean()
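A minimal usage sketch for the hinge variant constructed in Encodec.__init__ above (not part of the commit; the batch and score shapes are assumptions):

import torch
from loss import GeneratorAdversarialLoss

gen_adv = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")
# One score tensor per discriminator, e.g. three toy discriminators.
scores_hat = [torch.randn(2, 50) for _ in range(3)]
# Averages F.relu(1 - score).mean() over the three discriminators.
loss = gen_adv(outputs=scores_hat)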
class DiscriminatorAdversarialLoss(torch.nn.Module):
"""Discriminator adversarial loss module."""
def __init__(
self,
average_by_discriminators: bool = True,
loss_type: str = "hinge",
):
"""Initialize DiscriminatorAversarialLoss module.
Args:
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
loss_type (str): Loss type, "mse" or "hinge".
"""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.fake_criterion = self._mse_fake_loss
self.real_criterion = self._mse_real_loss
else:
self.fake_criterion = self._hinge_fake_loss
self.real_criterion = self._hinge_real_loss
def forward(
self,
outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calcualate discriminator adversarial loss.
Args:
outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs calculated from generator.
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs calculated from groundtruth.
Returns:
Tensor: Discriminator real loss value.
Tensor: Discriminator fake loss value.
"""
if isinstance(outputs, (tuple, list)):
real_loss = 0.0
fake_loss = 0.0
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
if isinstance(outputs_hat_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_hat_ = outputs_hat_[-1]
outputs_ = outputs_[-1]
real_loss += self.real_criterion(outputs_)
fake_loss += self.fake_criterion(outputs_hat_)
if self.average_by_discriminators:
fake_loss /= i + 1
real_loss /= i + 1
else:
real_loss = self.real_criterion(outputs)
fake_loss = self.fake_criterion(outputs_hat)
return real_loss, fake_loss
def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_ones(x.size()))
def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_zeros(x.size()))
def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.relu(x.new_ones(x.size()) - x).mean()
def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.relu(x.new_ones(x.size()) + x).mean()
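A matching sketch for the discriminator side (toy shapes assumed): the hinge criterion pushes real scores above +1 and generated scores below -1.

import torch
from loss import DiscriminatorAdversarialLoss

disc_adv = DiscriminatorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")
scores_real = [torch.randn(2, 50) for _ in range(3)]
scores_fake = [torch.randn(2, 50) for _ in range(3)]
# real_loss averages relu(1 - real); fake_loss averages relu(1 + fake).
real_loss, fake_loss = disc_adv(outputs_hat=scores_fake, outputs=scores_real)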
class FeatureMatchLoss(torch.nn.Module):
"""Feature matching loss module."""
def __init__(
self,
average_by_layers: bool = True,
average_by_discriminators: bool = True,
include_final_outputs: bool = False,
):
"""Initialize FeatureMatchLoss module.
Args:
average_by_layers (bool): Whether to average the loss by the number
of layers.
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
include_final_outputs (bool): Whether to include the final output of
each discriminator for loss calculation.
"""
super().__init__()
self.average_by_layers = average_by_layers
self.average_by_discriminators = average_by_discriminators
self.include_final_outputs = include_final_outputs
def forward(
self,
feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
) -> torch.Tensor:
"""Calculate feature matching loss.
Args:
feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
discriminator outputs or list of discriminator outputs calculated
from generator's outputs.
feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
discriminator outputs or list of discriminator outputs calculated
from groundtruth.
Returns:
Tensor: Feature matching loss value.
"""
feat_match_loss = 0.0
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
feat_match_loss_ = 0.0
if not self.include_final_outputs:
feats_hat_ = feats_hat_[:-1]
feats_ = feats_[:-1]
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
if self.average_by_layers:
feat_match_loss_ /= j + 1
feat_match_loss += feat_match_loss_
if self.average_by_discriminators:
feat_match_loss /= i + 1
return feat_match_loss
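Usage sketch (shapes are assumptions): each discriminator exposes a list of intermediate feature maps, and the loss is the L1 distance between matched layers with the ground-truth side detached; include_final_outputs=False drops the final (score) layer, as in the Encodec model above.

import torch
from loss import FeatureMatchLoss

fm = FeatureMatchLoss(average_by_layers=False, include_final_outputs=False)
# Two toy discriminators with three feature maps each.
feats = [[torch.randn(2, 16, 50) for _ in range(3)] for _ in range(2)]
feats_hat = [[torch.randn(2, 16, 50) for _ in range(3)] for _ in range(2)]
loss = fm(feats_hat=feats_hat, feats=feats)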
class MelSpectrogramReconstructionLoss(torch.nn.Module):
"""Mel Spec Reconstruction loss."""
def __init__(
self,
sampling_rate: int = 22050,
n_mels: int = 64,
use_fft_mag: bool = True,
return_mel: bool = False,
):
super().__init__()
self.wav_to_specs = []
for i in range(5, 12):
s = 2**i
# self.wav_to_specs.append(
# Wav2LogFilterBank(
# sampling_rate=sampling_rate,
# frame_length=s,
# frame_shift=s // 4,
# use_fft_mag=use_fft_mag,
# num_filters=n_mels,
# )
# )
self.wav_to_specs.append(
MelSpectrogram(
sample_rate=sampling_rate,
n_fft=max(s, 512),
win_length=s,
hop_length=s // 4,
n_mels=n_mels,
)
)
self.return_mel = return_mel
def forward(
self,
x_hat: torch.Tensor,
x: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
"""Calculate Mel-spectrogram loss.
Args:
x_hat (Tensor): Generated waveform tensor (B, 1, T).
x (Tensor): Groundtruth waveform tensor (B, 1, T).
Returns:
Tensor: Mel-spectrogram loss value.
"""
mel_loss = 0.0
for i, wav_to_spec in enumerate(self.wav_to_specs):
s = 2 ** (i + 5)
wav_to_spec.to(x.device)
mel_hat = wav_to_spec(x_hat.squeeze(1))
mel = wav_to_spec(x.squeeze(1))
alpha = (s / 2) ** 0.5
mel_loss += F.l1_loss(mel_hat, mel) + alpha * F.mse_loss(mel_hat, mel)
# mel_hat = self.wav_to_spec(x_hat.squeeze(1))
# mel = self.wav_to_spec(x.squeeze(1))
# mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel)
if self.return_mel:
return mel_loss, (mel_hat, mel)
return mel_loss
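The mel term is multi-scale: seven torchaudio MelSpectrogram transforms with window sizes 2^5 through 2^11 (hop = win // 4, n_fft = max(win, 512)), each contributing an L1 term plus an L2 term weighted by alpha = sqrt(win / 2). A hedged usage sketch (the 24 kHz rate and one-second batch are assumptions):

import torch
from loss import MelSpectrogramReconstructionLoss

mel_loss_fn = MelSpectrogramReconstructionLoss(sampling_rate=24000, n_mels=64)
x = torch.randn(2, 1, 24000)      # ground truth, one second at 24 kHz
x_hat = torch.randn(2, 1, 24000)  # generator output
loss = mel_loss_fn(x_hat=x_hat, x=x)  # summed over the seven scales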
class WavReconstructionLoss(torch.nn.Module):
"""Wav Reconstruction loss."""
def __init__(self):
super().__init__()
def forward(
self,
x_hat: torch.Tensor,
x: torch.Tensor,
) -> torch.Tensor:
"""Calculate wav loss.
Args:
x_hat (Tensor): Generated waveform tensor (B, 1, T).
x (Tensor): Groundtruth waveform tensor (B, 1, T).
Returns:
Tensor: Wav loss value.
"""
wav_loss = F.mse_loss(x, x_hat)
return wav_loss
def adversarial_g_loss(y_disc_gen):
"""Hinge loss"""
loss = 0.0
@@ -63,88 +365,12 @@ def reconstruction_loss(x, x_hat, args, eps=1e-7):
return L
def criterion_d(
y_disc_r,
y_disc_gen,
fmap_r_det,
fmap_gen_det,
y_df_hat_r,
y_df_hat_g,
fmap_f_r,
fmap_f_g,
y_ds_hat_r,
y_ds_hat_g,
fmap_s_r,
fmap_s_g,
):
"""Hinge Loss"""
loss = 0.0
loss1 = 0.0
loss2 = 0.0
loss3 = 0.0
for i in range(len(y_disc_r)):
loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[i]).mean()
for i in range(len(y_df_hat_r)):
loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[i]).mean()
for i in range(len(y_ds_hat_r)):
loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[i]).mean()
loss = (
loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 / len(y_ds_hat_r)
) / 3.0
return loss
def criterion_g(
commit_loss,
x,
G_x,
fmap_r,
fmap_gen,
y_disc_r,
y_disc_gen,
y_df_hat_r,
y_df_hat_g,
fmap_f_r,
fmap_f_g,
y_ds_hat_r,
y_ds_hat_g,
fmap_s_r,
fmap_s_g,
args,
):
adv_g_loss = adversarial_g_loss(y_disc_gen)
feat_loss = (
feature_loss(fmap_r, fmap_gen)
+ sim_loss(y_disc_r, y_disc_gen)
+ feature_loss(fmap_f_r, fmap_f_g)
+ sim_loss(y_df_hat_r, y_df_hat_g)
+ feature_loss(fmap_s_r, fmap_s_g)
+ sim_loss(y_ds_hat_r, y_ds_hat_g)
) / 3.0
rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args)
total_loss = (
args.lambda_com * commit_loss
+ args.lambda_adv * adv_g_loss
+ args.lambda_feat * feat_loss
+ args.lambda_rec * rec_loss
)
return total_loss, adv_g_loss, feat_loss, rec_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
if global_step % 3 == 0:
weight = value
return weight
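adopt_weight zeroes a loss weight until global_step reaches threshold; train.py now calls it directly to keep the adversarial and feature-matching terms off until discriminator_iter_start. A small illustration (values assumed):

from loss import adopt_weight

for step in (0, 499, 500, 1000):
    # prints 0.0 while the discriminator warms up, 1.0 afterwards
    print(step, adopt_weight(1.0, step, threshold=500))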
def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
@@ -166,7 +392,6 @@ def loss_g(
fmap_hat,
y,
y_hat,
global_step,
y_df,
y_df_hat,
y_ds,
@@ -215,9 +440,10 @@ def loss_g(
feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
d_weight = torch.tensor(1.0)
disc_factor = adopt_weight(
args.lambda_adv, global_step, threshold=args.discriminator_iter_start
)
# disc_factor = adopt_weight(
# args.lambda_adv, global_step, threshold=args.discriminator_iter_start
# )
disc_factor = 1
if disc_factor == 0.0:
fm_loss_wt = 0
else:
@@ -232,37 +458,9 @@ def loss_g(
return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
def loss_dis(
y,
y_hat,
fmap,
fmap_hat,
y_df,
y_df_hat,
fmap_f,
fmap_f_hat,
y_ds,
y_ds_hat,
fmap_s,
fmap_s_hat,
global_step,
args,
):
disc_factor = adopt_weight(
args.lambda_adv, global_step, threshold=args.discriminator_iter_start
)
d_loss = disc_factor * criterion_d(
y,
y_hat,
fmap,
fmap_hat,
y_df,
y_df_hat,
fmap_f,
fmap_f_hat,
y_ds,
y_ds_hat,
fmap_s,
fmap_s_hat,
)
return d_loss
if __name__ == "__main__":
la = FeatureMatchLoss(average_by_layers=False, average_by_discriminators=False)
aa = [torch.rand(192, 192) for _ in range(3)]
bb = [torch.rand(192, 192) for _ in range(3)]
print(la(bb, aa))
print(feature_loss(aa, bb))

View File

@@ -15,12 +15,13 @@ from codec_datamodule import LibriTTSCodecDataModule
from encodec import Encodec
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from loss import adopt_weight
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from utils import MetricsTracker, plot_feature, save_checkpoint
from utils import MetricsTracker, save_checkpoint
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint
@@ -250,11 +251,26 @@ def get_model(params: AttributeDict) -> nn.Module:
from modules.seanet import SEANetDecoder, SEANetEncoder
from quantization import ResidualVectorQuantizer
# generator_params = {
# "generator_n_filters": 32,
# "dimension": 512,
# "ratios": [2, 2, 2, 4],
# "target_bandwidths": [7.5, 15],
# "bins": 1024,
# }
# discriminator_params = {
# "stft_discriminator_n_filters": 32,
# "discriminator_iter_start": 500,
# }
# inference_params = {
# "target_bw": 7.5,
# }
generator_params = {
"generator_n_filters": 32,
"dimension": 512,
"ratios": [2, 2, 2, 4],
"target_bandwidths": [7.5, 15],
"ratios": [8, 5, 4, 2],
"target_bandwidths": [1.5, 3, 6, 12, 24],
"bins": 1024,
}
discriminator_params = {
@@ -262,7 +278,7 @@ def get_model(params: AttributeDict) -> nn.Module:
"discriminator_iter_start": 500,
}
inference_params = {
"target_bw": 7.5,
"target_bw": 12,
}
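For context (assuming 24 kHz LibriTTS audio and the usual EnCodec layout): ratios [8, 5, 4, 2] give a total stride of 8 * 5 * 4 * 2 = 320 samples, i.e. 75 frames per second, and a 1024-entry codebook costs 10 bits per frame (0.75 kbps), so target_bw=12 selects 16 residual codebooks and the 24 kbps ceiling corresponds to 32.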
params.update(generator_params)
@@ -419,36 +435,93 @@ def train_one_epoch(
try:
with autocast(enabled=params.use_fp16):
d_weight = adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
)
# forward discriminator
loss_d, stats_d = model(
(
disc_stft_real_adv_loss,
disc_stft_fake_adv_loss,
disc_period_real_adv_loss,
disc_period_fake_adv_loss,
disc_scale_real_adv_loss,
disc_scale_fake_adv_loss,
stats_d,
) = model(
speech=audio,
speech_lengths=audio_lens,
global_step=params.batch_idx_train,
return_sample=False,
forward_generator=False,
)
disc_loss = (
(
disc_stft_real_adv_loss
+ disc_stft_fake_adv_loss
+ disc_period_real_adv_loss
+ disc_period_fake_adv_loss
+ disc_scale_real_adv_loss
+ disc_scale_fake_adv_loss
)
* d_weight
/ 3
)
for k, v in stats_d.items():
loss_info[k] = v * batch_size
# update discriminator
optimizer_d.zero_grad()
scaler.scale(loss_d).backward()
scaler.scale(disc_loss).backward()
scaler.step(optimizer_d)
with autocast(enabled=params.use_fp16):
g_weight = adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
)
# forward generator
loss_g, stats_g = model(
(
commit_loss,
gen_stft_adv_loss,
gen_period_adv_loss,
gen_scale_adv_loss,
feature_stft_loss,
feature_period_loss,
feature_scale_loss,
wav_reconstruction_loss,
mel_reconstruction_loss,
stats_g,
) = model(
speech=audio,
speech_lengths=audio_lens,
global_step=params.batch_idx_train,
forward_generator=True,
return_sample=params.batch_idx_train % params.log_interval == 0,
)
gen_adv_loss = (
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
* g_weight
/ 3
)
feature_loss = (
feature_stft_loss + feature_period_loss + feature_scale_loss
) / 3
reconstruction_loss = (
params.lambda_wav * wav_reconstruction_loss
+ mel_reconstruction_loss
)
gen_loss = (
gen_adv_loss
+ params.lambda_rec * reconstruction_loss
+ (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
+ params.lambda_com * commit_loss
)
for k, v in stats_g.items():
if "returned_sample" not in k:
loss_info[k] = v * batch_size
# update generator
optimizer_g.zero_grad()
scaler.scale(loss_g).backward()
scaler.scale(gen_loss).backward()
scaler.step(optimizer_g)
scaler.update()
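Spelled out, the two objectives assembled above are

disc_loss = d_weight / 3 * (stft_real + stft_fake + period_real + period_fake + scale_real + scale_fake)
gen_loss = g_weight / 3 * (stft_adv + period_adv + scale_adv)
         + lambda_rec * (lambda_wav * wav_rec + mel_rec)
         + lambda_feat * feat_match   # gated off while g_weight == 0
         + lambda_com * commit

where d_weight and g_weight both come from adopt_weight, so the adversarial and feature-matching parts only switch on after discriminator_iter_start steps.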
@@ -619,27 +692,84 @@ def compute_validation_loss(
loss_info = MetricsTracker()
loss_info["samples"] = batch_size
d_weight = adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
)
# forward discriminator
loss_d, stats_d = model(
(
disc_stft_real_adv_loss,
disc_stft_fake_adv_loss,
disc_period_real_adv_loss,
disc_period_fake_adv_loss,
disc_scale_real_adv_loss,
disc_scale_fake_adv_loss,
stats_d,
) = model(
speech=audio,
speech_lengths=audio_lens,
global_step=params.batch_idx_train,
return_sample=False,
forward_generator=False,
)
assert loss_d.requires_grad is False
disc_loss = (
(
disc_stft_real_adv_loss
+ disc_stft_fake_adv_loss
+ disc_period_real_adv_loss
+ disc_period_fake_adv_loss
+ disc_scale_real_adv_loss
+ disc_scale_fake_adv_loss
)
* d_weight
/ 3
)
assert disc_loss.requires_grad is False
for k, v in stats_d.items():
loss_info[k] = v * batch_size
g_weight = adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
)
# forward generator
loss_g, stats_g = model(
(
commit_loss,
gen_stft_adv_loss,
gen_period_adv_loss,
gen_scale_adv_loss,
feature_stft_loss,
feature_period_loss,
feature_scale_loss,
wav_reconstruction_loss,
mel_reconstruction_loss,
stats_g,
) = model(
speech=audio,
speech_lengths=audio_lens,
global_step=params.batch_idx_train,
forward_generator=True,
return_sample=False,
)
assert loss_g.requires_grad is False
gen_adv_loss = (
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
* g_weight
/ 3
)
feature_loss = (
feature_stft_loss + feature_period_loss + feature_scale_loss
) / 3
reconstruction_loss = (
params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss
)
gen_loss = (
gen_adv_loss
+ params.lambda_rec * reconstruction_loss
+ (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
+ params.lambda_com * commit_loss
)
assert gen_loss.requires_grad is False
for k, v in stats_g.items():
if "returned_sample" not in k:
loss_info[k] = v * batch_size
@@ -691,24 +821,74 @@ def scan_pessimistic_batches_for_oom(
try:
# for discriminator
with autocast(enabled=params.use_fp16):
loss_d, stats_d = model(
(
disc_stft_real_adv_loss,
disc_stft_fake_adv_loss,
disc_period_real_adv_loss,
disc_period_fake_adv_loss,
disc_scale_real_adv_loss,
disc_scale_fake_adv_loss,
stats_d,
) = model(
speech=audio,
speech_lengths=audio_lens,
global_step=params.batch_idx_train,
return_sample=False,
forward_generator=False,
)
loss_d = (
(
disc_stft_real_adv_loss
+ disc_stft_fake_adv_loss
+ disc_period_real_adv_loss
+ disc_period_fake_adv_loss
+ disc_scale_real_adv_loss
+ disc_scale_fake_adv_loss
)
* adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
)
/ 3
)
optimizer_d.zero_grad()
loss_d.backward()
# for generator
with autocast(enabled=params.use_fp16):
loss_g, stats_g = model(
(
commit_loss,
gen_stft_adv_loss,
gen_period_adv_loss,
gen_scale_adv_loss,
feature_stft_loss,
feature_period_loss,
feature_scale_loss,
wav_reconstruction_loss,
mel_reconstruction_loss,
stats_g,
) = model(
speech=audio,
speech_lengths=audio_lens,
forward_generator=True,
global_step=params.batch_idx_train,
return_sample=False,
)
loss_g = (
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
* adopt_weight(
params.lambda_adv,
params.batch_idx_train,
threshold=params.discriminator_iter_start,
)
/ 3
+ params.lambda_rec
* (
params.lambda_wav * wav_reconstruction_loss
+ mel_reconstruction_loss
)
+ params.lambda_feat
* (feature_stft_loss + feature_period_loss + feature_scale_loss)
+ params.lambda_com * commit_loss
)
optimizer_g.zero_grad()
loss_g.backward()
except Exception as e: