refactored loss functions

JinZr 2024-10-05 23:11:43 +08:00
parent 1e65a976d0
commit f9340cc5d7
5 changed files with 620 additions and 186 deletions

View File

@@ -139,7 +139,7 @@ class LibriTTSCodecDataModule:
         group.add_argument(
             "--num-workers",
             type=int,
-            default=2,
+            default=8,
             help="The number of training dataloader workers that "
             "collect the batches.",
         )

View File

@@ -4,7 +4,13 @@ from typing import List
 import numpy as np
 import torch
-from loss import loss_dis, loss_g
+from loss import (
+    DiscriminatorAdversarialLoss,
+    FeatureMatchLoss,
+    GeneratorAdversarialLoss,
+    MelSpectrogramReconstructionLoss,
+    WavReconstructionLoss,
+)
 from torch import nn
 from torch.cuda.amp import autocast
@@ -47,11 +53,23 @@ class Encodec(nn.Module):
         self.cache_generator_outputs = cache_generator_outputs
         self._cache = None

+        # construct loss functions
+        self.generator_adversarial_loss = GeneratorAdversarialLoss(
+            average_by_discriminators=True, loss_type="hinge"
+        )
+        self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
+            average_by_discriminators=True, loss_type="hinge"
+        )
+        self.feature_match_loss = FeatureMatchLoss(average_by_layers=False)
+        self.wav_reconstruction_loss = WavReconstructionLoss()
+        self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
+            sampling_rate=self.sampling_rate
+        )
+
     def _forward_generator(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        global_step: int,
         return_sample: bool = False,
     ):
         """Perform generator forward.
@@ -59,7 +77,6 @@ class Encodec(nn.Module):
         Args:
             speech (Tensor): Speech waveform tensor (B, T_wav).
             speech_lengths (Tensor): Speech length tensor (B,).
-            global_step (int): Global step.
             return_sample (bool): Return the generator output.

         Returns:
@@ -107,33 +124,56 @@ class Encodec(nn.Module):
         # calculate losses
         with autocast(enabled=False):
-            loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
-                commit_loss,
-                speech,
-                speech_hat,
-                fmap,
-                fmap_hat,
-                y,
-                y_hat,
-                global_step,
-                y_p,
-                y_p_hat,
-                y_s,
-                y_s_hat,
-                fmap_p,
-                fmap_p_hat,
-                fmap_s,
-                fmap_s_hat,
-                args=self.params,
-            )
+            gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat)
+            gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
+            gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)
+
+            feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat)
+            feature_period_loss = self.feature_match_loss(
+                feats=fmap_p, feats_hat=fmap_p_hat
+            )
+            feature_scale_loss = self.feature_match_loss(
+                feats=fmap_s, feats_hat=fmap_s_hat
+            )
+            wav_reconstruction_loss = self.wav_reconstruction_loss(
+                x=speech, x_hat=speech_hat
+            )
+            mel_reconstruction_loss = self.mel_reconstruction_loss(
+                x=speech, x_hat=speech_hat
+            )
+            # loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
+            #     commit_loss,
+            #     speech,
+            #     speech_hat,
+            #     fmap,
+            #     fmap_hat,
+            #     y,
+            #     y_hat,
+            #     y_p,
+            #     y_p_hat,
+            #     y_s,
+            #     y_s_hat,
+            #     fmap_p,
+            #     fmap_p_hat,
+            #     fmap_s,
+            #     fmap_s_hat,
+            #     args=self.params,
+            # )

         stats = dict(
-            generator_loss=loss.item(),
-            generator_reconstruction_loss=rec_loss.item(),
-            generator_feature_loss=feat_loss.item(),
-            generator_adv_loss=adv_loss.item(),
+            # generator_loss=loss.item(),
+            generator_wav_reconstruction_loss=wav_reconstruction_loss.item(),
+            generator_mel_reconstruction_loss=mel_reconstruction_loss.item(),
+            generator_feature_stft_loss=feature_stft_loss.item(),
+            generator_feature_period_loss=feature_period_loss.item(),
+            generator_feature_scale_loss=feature_scale_loss.item(),
+            generator_stft_adv_loss=gen_stft_adv_loss.item(),
+            generator_period_adv_loss=gen_period_adv_loss.item(),
+            generator_scale_adv_loss=gen_scale_adv_loss.item(),
             generator_commit_loss=commit_loss.item(),
-            d_weight=d_weight.item(),
+            # d_weight=d_weight.item(),
         )

         if return_sample:
@@ -147,19 +187,28 @@ class Encodec(nn.Module):
         # reset cache
         if reuse_cache or not self.training:
             self._cache = None

-        return loss, stats
+        return (
+            commit_loss,
+            gen_stft_adv_loss,
+            gen_period_adv_loss,
+            gen_scale_adv_loss,
+            feature_stft_loss,
+            feature_period_loss,
+            feature_scale_loss,
+            wav_reconstruction_loss,
+            mel_reconstruction_loss,
+            stats,
+        )

     def _forward_discriminator(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        global_step: int,
     ):
         """
         Args:
             speech (Tensor): Speech waveform tensor (B, T_wav).
             speech_lengths (Tensor): Speech length tensor (B,).
-            global_step (int): Global step.

         Returns:
             * loss (Tensor): Loss scalar tensor.
@@ -206,37 +255,46 @@ class Encodec(nn.Module):
         )

         # calculate losses
         with autocast(enabled=False):
-            loss = loss_dis(
-                y,
-                y_hat,
-                fmap,
-                fmap_hat,
-                y_p,
-                y_p_hat,
-                fmap_p,
-                fmap_p_hat,
-                y_s,
-                y_s_hat,
-                fmap_s,
-                fmap_s_hat,
-                global_step,
-                args=self.params,
-            )
+            (
+                disc_stft_real_adv_loss,
+                disc_stft_fake_adv_loss,
+            ) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat)
+            (
+                disc_period_real_adv_loss,
+                disc_period_fake_adv_loss,
+            ) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat)
+            (
+                disc_scale_real_adv_loss,
+                disc_scale_fake_adv_loss,
+            ) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat)

         stats = dict(
-            discriminator_loss=loss.item(),
+            discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(),
+            discriminator_period_real_adv_loss=disc_period_real_adv_loss.item(),
+            discriminator_scale_real_adv_loss=disc_scale_real_adv_loss.item(),
+            discriminator_stft_fake_adv_loss=disc_stft_fake_adv_loss.item(),
+            discriminator_period_fake_adv_loss=disc_period_fake_adv_loss.item(),
+            discriminator_scale_fake_adv_loss=disc_scale_fake_adv_loss.item(),
         )

         # reset cache
         if reuse_cache or not self.training:
             self._cache = None

-        return loss, stats
+        return (
+            disc_stft_real_adv_loss,
+            disc_stft_fake_adv_loss,
+            disc_period_real_adv_loss,
+            disc_period_fake_adv_loss,
+            disc_scale_real_adv_loss,
+            disc_scale_fake_adv_loss,
+            stats,
+        )
     def forward(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        global_step: int,
         return_sample: bool,
         forward_generator: bool,
     ):
@@ -244,14 +302,12 @@ class Encodec(nn.Module):
             return self._forward_generator(
                 speech=speech,
                 speech_lengths=speech_lengths,
-                global_step=global_step,
                 return_sample=return_sample,
             )
         else:
             return self._forward_discriminator(
                 speech=speech,
                 speech_lengths=speech_lengths,
-                global_step=global_step,
             )

     def encode(self, x, target_bw=None, st=None):

View File

@@ -71,7 +71,7 @@ def get_parser():
     parser.add_argument(
         "--target-bw",
         type=float,
-        default=7.5,
+        default=24,
         help="The target bandwidth for the generator",
     )

View File

@@ -1,8 +1,310 @@
+from typing import List, Tuple, Union
+
 import torch
 import torch.nn.functional as F
+from lhotse.features.kaldi import Wav2LogFilterBank
 from torchaudio.transforms import MelSpectrogram
+
+
+class GeneratorAdversarialLoss(torch.nn.Module):
+    """Generator adversarial loss module."""
+
+    def __init__(
+        self,
+        average_by_discriminators: bool = True,
+        loss_type: str = "hinge",
+    ):
+        """Initialize GeneratorAdversarialLoss module.
+
+        Args:
+            average_by_discriminators (bool): Whether to average the loss by
+                the number of discriminators.
+            loss_type (str): Loss type, "mse" or "hinge".
+        """
+        super().__init__()
+        self.average_by_discriminators = average_by_discriminators
+        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
+        if loss_type == "mse":
+            self.criterion = self._mse_loss
+        else:
+            self.criterion = self._hinge_loss
+
+    def forward(
+        self,
+        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+    ) -> torch.Tensor:
+        """Calculate generator adversarial loss.
+
+        Args:
+            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+                outputs, list of discriminator outputs, or list of list of
+                discriminator outputs.
+
+        Returns:
+            Tensor: Generator adversarial loss value.
+        """
+        if isinstance(outputs, (tuple, list)):
+            adv_loss = 0.0
+            for i, outputs_ in enumerate(outputs):
+                if isinstance(outputs_, (tuple, list)):
+                    # NOTE(kan-bayashi): case including feature maps
+                    outputs_ = outputs_[-1]
+                adv_loss += self.criterion(outputs_)
+            if self.average_by_discriminators:
+                adv_loss /= i + 1
+        else:
+            adv_loss = self.criterion(outputs)
+
+        return adv_loss
+
+    def _mse_loss(self, x):
+        return F.mse_loss(x, x.new_ones(x.size()))
+
+    def _hinge_loss(self, x):
+        return F.relu(1 - x).mean()
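
A minimal usage sketch of GeneratorAdversarialLoss (not part of the commit; shapes and head count are illustrative). With the default hinge setting, each discriminator head's score map for generated audio is penalized via relu(1 - score).mean(), and the sum is averaged over heads:

    import torch
    from loss import GeneratorAdversarialLoss

    criterion = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")
    # e.g. three discriminator heads, each scoring a batch of four generated clips
    fake_scores = [torch.randn(4, 1, 100) for _ in range(3)]
    adv_loss = criterion(outputs=fake_scores)  # scalar tensor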
+
+
+class DiscriminatorAdversarialLoss(torch.nn.Module):
+    """Discriminator adversarial loss module."""
+
+    def __init__(
+        self,
+        average_by_discriminators: bool = True,
+        loss_type: str = "hinge",
+    ):
+        """Initialize DiscriminatorAdversarialLoss module.
+
+        Args:
+            average_by_discriminators (bool): Whether to average the loss by
+                the number of discriminators.
+            loss_type (str): Loss type, "mse" or "hinge".
+        """
+        super().__init__()
+        self.average_by_discriminators = average_by_discriminators
+        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
+        if loss_type == "mse":
+            self.fake_criterion = self._mse_fake_loss
+            self.real_criterion = self._mse_real_loss
+        else:
+            self.fake_criterion = self._hinge_fake_loss
+            self.real_criterion = self._hinge_real_loss
+
+    def forward(
+        self,
+        outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Calculate discriminator adversarial loss.
+
+        Args:
+            outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+                outputs, list of discriminator outputs, or list of list of discriminator
+                outputs calculated from generator.
+            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+                outputs, list of discriminator outputs, or list of list of discriminator
+                outputs calculated from groundtruth.
+
+        Returns:
+            Tensor: Discriminator real loss value.
+            Tensor: Discriminator fake loss value.
+        """
+        if isinstance(outputs, (tuple, list)):
+            real_loss = 0.0
+            fake_loss = 0.0
+            for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
+                if isinstance(outputs_hat_, (tuple, list)):
+                    # NOTE(kan-bayashi): case including feature maps
+                    outputs_hat_ = outputs_hat_[-1]
+                    outputs_ = outputs_[-1]
+                real_loss += self.real_criterion(outputs_)
+                fake_loss += self.fake_criterion(outputs_hat_)
+            if self.average_by_discriminators:
+                fake_loss /= i + 1
+                real_loss /= i + 1
+        else:
+            real_loss = self.real_criterion(outputs)
+            fake_loss = self.fake_criterion(outputs_hat)
+
+        return real_loss, fake_loss
+
+    def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
+        return F.mse_loss(x, x.new_ones(x.size()))
+
+    def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
+        return F.mse_loss(x, x.new_zeros(x.size()))
+
+    def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
+        return F.relu(x.new_ones(x.size()) - x).mean()
+
+    def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
+        return F.relu(x.new_ones(x.size()) + x).mean()
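
The discriminator counterpart returns the real and fake hinge terms separately (real scores are pushed above +1, fake scores below -1), which is what lets the refactored training code log per-term statistics instead of one blended value. Continuing the illustrative sketch above:

    from loss import DiscriminatorAdversarialLoss

    criterion = DiscriminatorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")
    real_scores = [torch.randn(4, 1, 100) for _ in range(3)]
    fake_scores = [torch.randn(4, 1, 100) for _ in range(3)]
    real_loss, fake_loss = criterion(outputs_hat=fake_scores, outputs=real_scores)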
+
+
+class FeatureMatchLoss(torch.nn.Module):
+    """Feature matching loss module."""
+
+    def __init__(
+        self,
+        average_by_layers: bool = True,
+        average_by_discriminators: bool = True,
+        include_final_outputs: bool = False,
+    ):
+        """Initialize FeatureMatchLoss module.
+
+        Args:
+            average_by_layers (bool): Whether to average the loss by the number
+                of layers.
+            average_by_discriminators (bool): Whether to average the loss by
+                the number of discriminators.
+            include_final_outputs (bool): Whether to include the final output of
+                each discriminator for loss calculation.
+        """
+        super().__init__()
+        self.average_by_layers = average_by_layers
+        self.average_by_discriminators = average_by_discriminators
+        self.include_final_outputs = include_final_outputs
+
+    def forward(
+        self,
+        feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
+        feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
+    ) -> torch.Tensor:
+        """Calculate feature matching loss.
+
+        Args:
+            feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
+                discriminator outputs or list of discriminator outputs calculated
+                from generator's outputs.
+            feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
+                discriminator outputs or list of discriminator outputs calculated
+                from groundtruth.
+
+        Returns:
+            Tensor: Feature matching loss value.
+        """
+        feat_match_loss = 0.0
+        for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
+            feat_match_loss_ = 0.0
+            if not self.include_final_outputs:
+                feats_hat_ = feats_hat_[:-1]
+                feats_ = feats_[:-1]
+            for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
+                feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
+            if self.average_by_layers:
+                feat_match_loss_ /= j + 1
+            feat_match_loss += feat_match_loss_
+        if self.average_by_discriminators:
+            feat_match_loss /= i + 1
+
+        return feat_match_loss
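
A sketch of FeatureMatchLoss on nested inputs (outer list = discriminators, inner list = layers; all shapes invented). Ground-truth features are detached, so gradients flow only through the generated branch:

    from loss import FeatureMatchLoss

    fm_loss = FeatureMatchLoss(average_by_layers=False)  # mirrors the encodec.py constructor above
    feats = [[torch.randn(4, 16, 50) for _ in range(4)] for _ in range(3)]
    feats_hat = [[torch.randn(4, 16, 50) for _ in range(4)] for _ in range(3)]
    loss = fm_loss(feats_hat=feats_hat, feats=feats)
    # include_final_outputs=False (the default) skips each head's final score map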
+
+
+class MelSpectrogramReconstructionLoss(torch.nn.Module):
+    """Mel Spec Reconstruction loss."""
+
+    def __init__(
+        self,
+        sampling_rate: int = 22050,
+        n_mels: int = 64,
+        use_fft_mag: bool = True,
+        return_mel: bool = False,
+    ):
+        super().__init__()
+        self.wav_to_specs = []
+        for i in range(5, 12):
+            s = 2**i
+            # self.wav_to_specs.append(
+            #     Wav2LogFilterBank(
+            #         sampling_rate=sampling_rate,
+            #         frame_length=s,
+            #         frame_shift=s // 4,
+            #         use_fft_mag=use_fft_mag,
+            #         num_filters=n_mels,
+            #     )
+            # )
+            self.wav_to_specs.append(
+                MelSpectrogram(
+                    sample_rate=sampling_rate,
+                    n_fft=max(s, 512),
+                    win_length=s,
+                    hop_length=s // 4,
+                    n_mels=n_mels,
+                )
+            )
+        self.return_mel = return_mel
+
+    def forward(
+        self,
+        x_hat: torch.Tensor,
+        x: torch.Tensor,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
+        """Calculate Mel-spectrogram loss.
+
+        Args:
+            x_hat (Tensor): Generated waveform tensor (B, 1, T).
+            x (Tensor): Groundtruth waveform tensor (B, 1, T).
+
+        Returns:
+            Tensor: Mel-spectrogram loss value.
+        """
+        mel_loss = 0.0
+        for i, wav_to_spec in enumerate(self.wav_to_specs):
+            s = 2 ** (i + 5)
+            wav_to_spec.to(x.device)
+            mel_hat = wav_to_spec(x_hat.squeeze(1))
+            mel = wav_to_spec(x.squeeze(1))
+            alpha = (s / 2) ** 0.5
+            mel_loss += F.l1_loss(mel_hat, mel) + alpha * F.mse_loss(mel_hat, mel)
+        # mel_hat = self.wav_to_spec(x_hat.squeeze(1))
+        # mel = self.wav_to_spec(x.squeeze(1))
+        # mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel)
+
+        if self.return_mel:
+            return mel_loss, (mel_hat, mel)
+
+        return mel_loss
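
This is a multi-scale mel reconstruction loss in the SoundStream/EnCodec style: for window sizes s = 2**i with i = 5..11 (seven resolutions, 32 to 2048 samples), it accumulates an L1 term plus an L2 term weighted by alpha = sqrt(s / 2). A usage sketch, assuming a 24 kHz signal:

    from loss import MelSpectrogramReconstructionLoss

    mel_loss_fn = MelSpectrogramReconstructionLoss(sampling_rate=24000)
    x = torch.randn(2, 1, 24000)      # ground-truth waveform (B, 1, T)
    x_hat = torch.randn(2, 1, 24000)  # generator output
    loss = mel_loss_fn(x_hat=x_hat, x=x)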
+
+
+class WavReconstructionLoss(torch.nn.Module):
+    """Wav Reconstruction loss."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self,
+        x_hat: torch.Tensor,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """Calculate wav loss.
+
+        Args:
+            x_hat (Tensor): Generated waveform tensor (B, 1, T).
+            x (Tensor): Groundtruth waveform tensor (B, 1, T).
+
+        Returns:
+            Tensor: Wav loss value.
+        """
+        wav_loss = F.mse_loss(x, x_hat)
+
+        return wav_loss
+
+
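
WavReconstructionLoss is plain time-domain MSE. Taken together, the refactored training loop in train.py (below) combines the terms roughly as follows, with the lambda_* weights coming from the training params (a paraphrase, not code from the commit):

    # gen_loss = g_weight * (adv_stft + adv_period + adv_scale) / 3
    #            + lambda_rec * (lambda_wav * wav_loss + mel_loss)
    #            + lambda_feat * (feat_stft + feat_period + feat_scale) / 3  # gated off while g_weight == 0
    #            + lambda_com * commit_loss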
 def adversarial_g_loss(y_disc_gen):
     """Hinge loss"""
     loss = 0.0
@@ -63,88 +365,12 @@ def reconstruction_loss(x, x_hat, args, eps=1e-7):
     return L

-
-def criterion_d(
-    y_disc_r,
-    y_disc_gen,
-    fmap_r_det,
-    fmap_gen_det,
-    y_df_hat_r,
-    y_df_hat_g,
-    fmap_f_r,
-    fmap_f_g,
-    y_ds_hat_r,
-    y_ds_hat_g,
-    fmap_s_r,
-    fmap_s_g,
-):
-    """Hinge Loss"""
-    loss = 0.0
-    loss1 = 0.0
-    loss2 = 0.0
-    loss3 = 0.0
-    for i in range(len(y_disc_r)):
-        loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[i]).mean()
-    for i in range(len(y_df_hat_r)):
-        loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[i]).mean()
-    for i in range(len(y_ds_hat_r)):
-        loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[i]).mean()
-    loss = (
-        loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 / len(y_ds_hat_r)
-    ) / 3.0
-    return loss
-
-
-def criterion_g(
-    commit_loss,
-    x,
-    G_x,
-    fmap_r,
-    fmap_gen,
-    y_disc_r,
-    y_disc_gen,
-    y_df_hat_r,
-    y_df_hat_g,
-    fmap_f_r,
-    fmap_f_g,
-    y_ds_hat_r,
-    y_ds_hat_g,
-    fmap_s_r,
-    fmap_s_g,
-    args,
-):
-    adv_g_loss = adversarial_g_loss(y_disc_gen)
-    feat_loss = (
-        feature_loss(fmap_r, fmap_gen)
-        + sim_loss(y_disc_r, y_disc_gen)
-        + feature_loss(fmap_f_r, fmap_f_g)
-        + sim_loss(y_df_hat_r, y_df_hat_g)
-        + feature_loss(fmap_s_r, fmap_s_g)
-        + sim_loss(y_ds_hat_r, y_ds_hat_g)
-    ) / 3.0
-    rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args)
-    total_loss = (
-        args.lambda_com * commit_loss
-        + args.lambda_adv * adv_g_loss
-        + args.lambda_feat * feat_loss
-        + args.lambda_rec * rec_loss
-    )
-    return total_loss, adv_g_loss, feat_loss, rec_loss
-
-
 def adopt_weight(weight, global_step, threshold=0, value=0.0):
     if global_step < threshold:
         weight = value
     return weight
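
adopt_weight implements the adversarial warm-up: it returns the given weight once global_step reaches the threshold, and the fallback value (0.0 by default) before that. For example:

    adopt_weight(1.0, global_step=100, threshold=500)  # -> 0.0, adversarial terms disabled
    adopt_weight(1.0, global_step=900, threshold=500)  # -> 1.0, adversarial terms active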
-def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
-    if global_step % 3 == 0:
-        weight = value
-    return weight
-
-
 def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
     if last_layer is not None:
         nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
@@ -166,7 +392,6 @@ def loss_g(
     fmap_hat,
     y,
     y_hat,
-    global_step,
     y_df,
     y_df_hat,
     y_ds,
@@ -215,9 +440,10 @@ def loss_g(
     feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
     d_weight = torch.tensor(1.0)

-    disc_factor = adopt_weight(
-        args.lambda_adv, global_step, threshold=args.discriminator_iter_start
-    )
+    # disc_factor = adopt_weight(
+    #     args.lambda_adv, global_step, threshold=args.discriminator_iter_start
+    # )
+    disc_factor = 1
     if disc_factor == 0.0:
         fm_loss_wt = 0
     else:
@@ -232,37 +458,9 @@ def loss_g(
     return loss, rec_loss, adv_loss, feat_loss_tot, d_weight

-def loss_dis(
-    y,
-    y_hat,
-    fmap,
-    fmap_hat,
-    y_df,
-    y_df_hat,
-    fmap_f,
-    fmap_f_hat,
-    y_ds,
-    y_ds_hat,
-    fmap_s,
-    fmap_s_hat,
-    global_step,
-    args,
-):
-    disc_factor = adopt_weight(
-        args.lambda_adv, global_step, threshold=args.discriminator_iter_start
-    )
-    d_loss = disc_factor * criterion_d(
-        y,
-        y_hat,
-        fmap,
-        fmap_hat,
-        y_df,
-        y_df_hat,
-        fmap_f,
-        fmap_f_hat,
-        y_ds,
-        y_ds_hat,
-        fmap_s,
-        fmap_s_hat,
-    )
-    return d_loss
+if __name__ == "__main__":
+    la = FeatureMatchLoss(average_by_layers=False, average_by_discriminators=False)
+    aa = [torch.rand(192, 192) for _ in range(3)]
+    bb = [torch.rand(192, 192) for _ in range(3)]
+    print(la(bb, aa))
+    print(feature_loss(aa, bb))

View File

@@ -15,12 +15,13 @@ from codec_datamodule import LibriTTSCodecDataModule
 from encodec import Encodec
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
+from loss import adopt_weight
 from torch import nn
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.utils.tensorboard import SummaryWriter
-from utils import MetricsTracker, plot_feature, save_checkpoint
+from utils import MetricsTracker, save_checkpoint

 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint
@@ -250,11 +251,26 @@ def get_model(params: AttributeDict) -> nn.Module:
     from modules.seanet import SEANetDecoder, SEANetEncoder
     from quantization import ResidualVectorQuantizer

+    # generator_params = {
+    #     "generator_n_filters": 32,
+    #     "dimension": 512,
+    #     "ratios": [2, 2, 2, 4],
+    #     "target_bandwidths": [7.5, 15],
+    #     "bins": 1024,
+    # }
+    # discriminator_params = {
+    #     "stft_discriminator_n_filters": 32,
+    #     "discriminator_iter_start": 500,
+    # }
+    # inference_params = {
+    #     "target_bw": 7.5,
+    # }
     generator_params = {
         "generator_n_filters": 32,
         "dimension": 512,
-        "ratios": [2, 2, 2, 4],
-        "target_bandwidths": [7.5, 15],
+        "ratios": [8, 5, 4, 2],
+        "target_bandwidths": [1.5, 3, 6, 12, 24],
         "bins": 1024,
     }
     discriminator_params = {
@@ -262,7 +278,7 @@ def get_model(params: AttributeDict) -> nn.Module:
         "discriminator_iter_start": 500,
     }
     inference_params = {
-        "target_bw": 7.5,
+        "target_bw": 12,
     }

     params.update(generator_params)
@@ -419,36 +435,93 @@ def train_one_epoch(
         try:
             with autocast(enabled=params.use_fp16):
+                d_weight = adopt_weight(
+                    params.lambda_adv,
+                    params.batch_idx_train,
+                    threshold=params.discriminator_iter_start,
+                )
                 # forward discriminator
-                loss_d, stats_d = model(
+                (
+                    disc_stft_real_adv_loss,
+                    disc_stft_fake_adv_loss,
+                    disc_period_real_adv_loss,
+                    disc_period_fake_adv_loss,
+                    disc_scale_real_adv_loss,
+                    disc_scale_fake_adv_loss,
+                    stats_d,
+                ) = model(
                     speech=audio,
                     speech_lengths=audio_lens,
-                    global_step=params.batch_idx_train,
                     return_sample=False,
                     forward_generator=False,
                 )
+                disc_loss = (
+                    (
+                        disc_stft_real_adv_loss
+                        + disc_stft_fake_adv_loss
+                        + disc_period_real_adv_loss
+                        + disc_period_fake_adv_loss
+                        + disc_scale_real_adv_loss
+                        + disc_scale_fake_adv_loss
+                    )
+                    * d_weight
+                    / 3
+                )
             for k, v in stats_d.items():
                 loss_info[k] = v * batch_size

             # update discriminator
             optimizer_d.zero_grad()
-            scaler.scale(loss_d).backward()
+            scaler.scale(disc_loss).backward()
             scaler.step(optimizer_d)

             with autocast(enabled=params.use_fp16):
+                g_weight = adopt_weight(
+                    params.lambda_adv,
+                    params.batch_idx_train,
+                    threshold=params.discriminator_iter_start,
+                )
                 # forward generator
-                loss_g, stats_g = model(
+                (
+                    commit_loss,
+                    gen_stft_adv_loss,
+                    gen_period_adv_loss,
+                    gen_scale_adv_loss,
+                    feature_stft_loss,
+                    feature_period_loss,
+                    feature_scale_loss,
+                    wav_reconstruction_loss,
+                    mel_reconstruction_loss,
+                    stats_g,
+                ) = model(
                     speech=audio,
                     speech_lengths=audio_lens,
-                    global_step=params.batch_idx_train,
                     forward_generator=True,
                     return_sample=params.batch_idx_train % params.log_interval == 0,
                 )
+                gen_adv_loss = (
+                    (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
+                    * g_weight
+                    / 3
+                )
+                feature_loss = (
+                    feature_stft_loss + feature_period_loss + feature_scale_loss
+                ) / 3
+                reconstruction_loss = (
+                    params.lambda_wav * wav_reconstruction_loss
+                    + mel_reconstruction_loss
+                )
+                gen_loss = (
+                    gen_adv_loss
+                    + params.lambda_rec * reconstruction_loss
+                    + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
+                    + params.lambda_com * commit_loss
+                )
             for k, v in stats_g.items():
                 if "returned_sample" not in k:
                     loss_info[k] = v * batch_size

             # update generator
             optimizer_g.zero_grad()
-            scaler.scale(loss_g).backward()
+            scaler.scale(gen_loss).backward()
             scaler.step(optimizer_g)
             scaler.update()
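
Note the warm-up behaviour of this loop: both d_weight and g_weight come from adopt_weight, so before discriminator_iter_start the discriminator trains on an identically zero loss and the generator sees only the reconstruction and commitment terms; the feature-matching term is likewise gated off while g_weight == 0.0. Concretely:

    # with lambda_adv = 1.0 and discriminator_iter_start = 500:
    #   batch_idx_train = 100 -> d_weight = g_weight = 0.0 (reconstruction-only phase)
    #   batch_idx_train = 600 -> d_weight = g_weight = 1.0 (adversarial + feature terms on)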
@@ -619,27 +692,84 @@ def compute_validation_loss(
             loss_info = MetricsTracker()
             loss_info["samples"] = batch_size

+            d_weight = adopt_weight(
+                params.lambda_adv,
+                params.batch_idx_train,
+                threshold=params.discriminator_iter_start,
+            )
             # forward discriminator
-            loss_d, stats_d = model(
+            (
+                disc_stft_real_adv_loss,
+                disc_stft_fake_adv_loss,
+                disc_period_real_adv_loss,
+                disc_period_fake_adv_loss,
+                disc_scale_real_adv_loss,
+                disc_scale_fake_adv_loss,
+                stats_d,
+            ) = model(
                 speech=audio,
                 speech_lengths=audio_lens,
-                global_step=params.batch_idx_train,
                 return_sample=False,
                 forward_generator=False,
             )
-            assert loss_d.requires_grad is False
+            disc_loss = (
+                (
+                    disc_stft_real_adv_loss
+                    + disc_stft_fake_adv_loss
+                    + disc_period_real_adv_loss
+                    + disc_period_fake_adv_loss
+                    + disc_scale_real_adv_loss
+                    + disc_scale_fake_adv_loss
+                )
+                * d_weight
+                / 3
+            )
+            assert disc_loss.requires_grad is False
             for k, v in stats_d.items():
                 loss_info[k] = v * batch_size

+            g_weight = adopt_weight(
+                params.lambda_adv,
+                params.batch_idx_train,
+                threshold=params.discriminator_iter_start,
+            )
             # forward generator
-            loss_g, stats_g = model(
+            (
+                commit_loss,
+                gen_stft_adv_loss,
+                gen_period_adv_loss,
+                gen_scale_adv_loss,
+                feature_stft_loss,
+                feature_period_loss,
+                feature_scale_loss,
+                wav_reconstruction_loss,
+                mel_reconstruction_loss,
+                stats_g,
+            ) = model(
                 speech=audio,
                 speech_lengths=audio_lens,
-                global_step=params.batch_idx_train,
                 forward_generator=True,
                 return_sample=False,
            )
-            assert loss_g.requires_grad is False
+            gen_adv_loss = (
+                (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
+                * g_weight
+                / 3
+            )
+            feature_loss = (
+                feature_stft_loss + feature_period_loss + feature_scale_loss
+            ) / 3
+            reconstruction_loss = (
+                params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss
+            )
+            gen_loss = (
+                gen_adv_loss
+                + params.lambda_rec * reconstruction_loss
+                + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
+                + params.lambda_com * commit_loss
+            )
+            assert gen_loss.requires_grad is False
             for k, v in stats_g.items():
                 if "returned_sample" not in k:
                     loss_info[k] = v * batch_size
@@ -691,24 +821,74 @@ def scan_pessimistic_batches_for_oom(
         try:
             # for discriminator
             with autocast(enabled=params.use_fp16):
-                loss_d, stats_d = model(
+                (
+                    disc_stft_real_adv_loss,
+                    disc_stft_fake_adv_loss,
+                    disc_period_real_adv_loss,
+                    disc_period_fake_adv_loss,
+                    disc_scale_real_adv_loss,
+                    disc_scale_fake_adv_loss,
+                    stats_d,
+                ) = model(
                     speech=audio,
                     speech_lengths=audio_lens,
-                    global_step=params.batch_idx_train,
                     return_sample=False,
                     forward_generator=False,
                 )
+                loss_d = (
+                    (
+                        disc_stft_real_adv_loss
+                        + disc_stft_fake_adv_loss
+                        + disc_period_real_adv_loss
+                        + disc_period_fake_adv_loss
+                        + disc_scale_real_adv_loss
+                        + disc_scale_fake_adv_loss
+                    )
+                    * adopt_weight(
+                        params.lambda_adv,
+                        params.batch_idx_train,
+                        threshold=params.discriminator_iter_start,
+                    )
+                    / 3
+                )
             optimizer_d.zero_grad()
             loss_d.backward()

             # for generator
             with autocast(enabled=params.use_fp16):
-                loss_g, stats_g = model(
+                (
+                    commit_loss,
+                    gen_stft_adv_loss,
+                    gen_period_adv_loss,
+                    gen_scale_adv_loss,
+                    feature_stft_loss,
+                    feature_period_loss,
+                    feature_scale_loss,
+                    wav_reconstruction_loss,
+                    mel_reconstruction_loss,
+                    stats_g,
+                ) = model(
                     speech=audio,
                     speech_lengths=audio_lens,
                     forward_generator=True,
-                    global_step=params.batch_idx_train,
                     return_sample=False,
                 )
+                loss_g = (
+                    (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
+                    * adopt_weight(
+                        params.lambda_adv,
+                        params.batch_idx_train,
+                        threshold=params.discriminator_iter_start,
+                    )
+                    / 3
+                    + params.lambda_rec
+                    * (
+                        params.lambda_wav * wav_reconstruction_loss
+                        + mel_reconstruction_loss
+                    )
+                    + params.lambda_feat
+                    * (feature_stft_loss + feature_period_loss + feature_scale_loss)
+                    + params.lambda_com * commit_loss
+                )
             optimizer_g.zero_grad()
             loss_g.backward()
         except Exception as e: