mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
refactored loss functions
This commit is contained in:
parent
1e65a976d0
commit
f9340cc5d7
@ -139,7 +139,7 @@ class LibriTTSCodecDataModule:
|
|||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--num-workers",
|
"--num-workers",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=8,
|
||||||
help="The number of training dataloader workers that "
|
help="The number of training dataloader workers that "
|
||||||
"collect the batches.",
|
"collect the batches.",
|
||||||
)
|
)
|
||||||
|
@ -4,7 +4,13 @@ from typing import List
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from loss import loss_dis, loss_g
|
from loss import (
|
||||||
|
DiscriminatorAdversarialLoss,
|
||||||
|
FeatureMatchLoss,
|
||||||
|
GeneratorAdversarialLoss,
|
||||||
|
MelSpectrogramReconstructionLoss,
|
||||||
|
WavReconstructionLoss,
|
||||||
|
)
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
|
||||||
@ -47,11 +53,23 @@ class Encodec(nn.Module):
|
|||||||
self.cache_generator_outputs = cache_generator_outputs
|
self.cache_generator_outputs = cache_generator_outputs
|
||||||
self._cache = None
|
self._cache = None
|
||||||
|
|
||||||
|
# construct loss functions
|
||||||
|
self.generator_adversarial_loss = GeneratorAdversarialLoss(
|
||||||
|
average_by_discriminators=True, loss_type="hinge"
|
||||||
|
)
|
||||||
|
self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
|
||||||
|
average_by_discriminators=True, loss_type="hinge"
|
||||||
|
)
|
||||||
|
self.feature_match_loss = FeatureMatchLoss(average_by_layers=False)
|
||||||
|
self.wav_reconstruction_loss = WavReconstructionLoss()
|
||||||
|
self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
|
||||||
|
sampling_rate=self.sampling_rate
|
||||||
|
)
|
||||||
|
|
||||||
def _forward_generator(
|
def _forward_generator(
|
||||||
self,
|
self,
|
||||||
speech: torch.Tensor,
|
speech: torch.Tensor,
|
||||||
speech_lengths: torch.Tensor,
|
speech_lengths: torch.Tensor,
|
||||||
global_step: int,
|
|
||||||
return_sample: bool = False,
|
return_sample: bool = False,
|
||||||
):
|
):
|
||||||
"""Perform generator forward.
|
"""Perform generator forward.
|
||||||
@ -59,7 +77,6 @@ class Encodec(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
speech (Tensor): Speech waveform tensor (B, T_wav).
|
speech (Tensor): Speech waveform tensor (B, T_wav).
|
||||||
speech_lengths (Tensor): Speech length tensor (B,).
|
speech_lengths (Tensor): Speech length tensor (B,).
|
||||||
global_step (int): Global step.
|
|
||||||
return_sample (bool): Return the generator output.
|
return_sample (bool): Return the generator output.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -107,33 +124,56 @@ class Encodec(nn.Module):
|
|||||||
|
|
||||||
# calculate losses
|
# calculate losses
|
||||||
with autocast(enabled=False):
|
with autocast(enabled=False):
|
||||||
loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
|
gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat)
|
||||||
commit_loss,
|
gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
|
||||||
speech,
|
gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)
|
||||||
speech_hat,
|
|
||||||
fmap,
|
feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat)
|
||||||
fmap_hat,
|
feature_period_loss = self.feature_match_loss(
|
||||||
y,
|
feats=fmap_p, feats_hat=fmap_p_hat
|
||||||
y_hat,
|
)
|
||||||
global_step,
|
feature_scale_loss = self.feature_match_loss(
|
||||||
y_p,
|
feats=fmap_s, feats_hat=fmap_s_hat
|
||||||
y_p_hat,
|
|
||||||
y_s,
|
|
||||||
y_s_hat,
|
|
||||||
fmap_p,
|
|
||||||
fmap_p_hat,
|
|
||||||
fmap_s,
|
|
||||||
fmap_s_hat,
|
|
||||||
args=self.params,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
wav_reconstruction_loss = self.wav_reconstruction_loss(
|
||||||
|
x=speech, x_hat=speech_hat
|
||||||
|
)
|
||||||
|
mel_reconstruction_loss = self.mel_reconstruction_loss(
|
||||||
|
x=speech, x_hat=speech_hat
|
||||||
|
)
|
||||||
|
|
||||||
|
# loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
|
||||||
|
# commit_loss,
|
||||||
|
# speech,
|
||||||
|
# speech_hat,
|
||||||
|
# fmap,
|
||||||
|
# fmap_hat,
|
||||||
|
# y,
|
||||||
|
# y_hat,
|
||||||
|
# y_p,
|
||||||
|
# y_p_hat,
|
||||||
|
# y_s,
|
||||||
|
# y_s_hat,
|
||||||
|
# fmap_p,
|
||||||
|
# fmap_p_hat,
|
||||||
|
# fmap_s,
|
||||||
|
# fmap_s_hat,
|
||||||
|
# args=self.params,
|
||||||
|
# )
|
||||||
|
|
||||||
stats = dict(
|
stats = dict(
|
||||||
generator_loss=loss.item(),
|
# generator_loss=loss.item(),
|
||||||
generator_reconstruction_loss=rec_loss.item(),
|
generator_wav_reconstruction_loss=wav_reconstruction_loss.item(),
|
||||||
generator_feature_loss=feat_loss.item(),
|
generator_mel_reconstruction_loss=mel_reconstruction_loss.item(),
|
||||||
generator_adv_loss=adv_loss.item(),
|
generator_feature_stft_loss=feature_stft_loss.item(),
|
||||||
|
generator_feature_period_loss=feature_period_loss.item(),
|
||||||
|
generator_feature_scale_loss=feature_scale_loss.item(),
|
||||||
|
generator_stft_adv_loss=gen_stft_adv_loss.item(),
|
||||||
|
generator_period_adv_loss=gen_period_adv_loss.item(),
|
||||||
|
generator_scale_adv_loss=gen_scale_adv_loss.item(),
|
||||||
generator_commit_loss=commit_loss.item(),
|
generator_commit_loss=commit_loss.item(),
|
||||||
d_weight=d_weight.item(),
|
# d_weight=d_weight.item(),
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_sample:
|
if return_sample:
|
||||||
@ -147,19 +187,28 @@ class Encodec(nn.Module):
|
|||||||
# reset cache
|
# reset cache
|
||||||
if reuse_cache or not self.training:
|
if reuse_cache or not self.training:
|
||||||
self._cache = None
|
self._cache = None
|
||||||
return loss, stats
|
return (
|
||||||
|
commit_loss,
|
||||||
|
gen_stft_adv_loss,
|
||||||
|
gen_period_adv_loss,
|
||||||
|
gen_scale_adv_loss,
|
||||||
|
feature_stft_loss,
|
||||||
|
feature_period_loss,
|
||||||
|
feature_scale_loss,
|
||||||
|
wav_reconstruction_loss,
|
||||||
|
mel_reconstruction_loss,
|
||||||
|
stats,
|
||||||
|
)
|
||||||
|
|
||||||
def _forward_discriminator(
|
def _forward_discriminator(
|
||||||
self,
|
self,
|
||||||
speech: torch.Tensor,
|
speech: torch.Tensor,
|
||||||
speech_lengths: torch.Tensor,
|
speech_lengths: torch.Tensor,
|
||||||
global_step: int,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
speech (Tensor): Speech waveform tensor (B, T_wav).
|
speech (Tensor): Speech waveform tensor (B, T_wav).
|
||||||
speech_lengths (Tensor): Speech length tensor (B,).
|
speech_lengths (Tensor): Speech length tensor (B,).
|
||||||
global_step (int): Global step.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
* loss (Tensor): Loss scalar tensor.
|
* loss (Tensor): Loss scalar tensor.
|
||||||
@ -206,37 +255,46 @@ class Encodec(nn.Module):
|
|||||||
)
|
)
|
||||||
# calculate losses
|
# calculate losses
|
||||||
with autocast(enabled=False):
|
with autocast(enabled=False):
|
||||||
loss = loss_dis(
|
(
|
||||||
y,
|
disc_stft_real_adv_loss,
|
||||||
y_hat,
|
disc_stft_fake_adv_loss,
|
||||||
fmap,
|
) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat)
|
||||||
fmap_hat,
|
(
|
||||||
y_p,
|
disc_period_real_adv_loss,
|
||||||
y_p_hat,
|
disc_period_fake_adv_loss,
|
||||||
fmap_p,
|
) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat)
|
||||||
fmap_p_hat,
|
(
|
||||||
y_s,
|
disc_scale_real_adv_loss,
|
||||||
y_s_hat,
|
disc_scale_fake_adv_loss,
|
||||||
fmap_s,
|
) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat)
|
||||||
fmap_s_hat,
|
|
||||||
global_step,
|
|
||||||
args=self.params,
|
|
||||||
)
|
|
||||||
stats = dict(
|
stats = dict(
|
||||||
discriminator_loss=loss.item(),
|
discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(),
|
||||||
|
discriminator_period_real_adv_loss=disc_period_real_adv_loss.item(),
|
||||||
|
discriminator_scale_real_adv_loss=disc_scale_real_adv_loss.item(),
|
||||||
|
discriminator_stft_fake_adv_loss=disc_stft_fake_adv_loss.item(),
|
||||||
|
discriminator_period_fake_adv_loss=disc_period_fake_adv_loss.item(),
|
||||||
|
discriminator_scale_fake_adv_loss=disc_scale_fake_adv_loss.item(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# reset cache
|
# reset cache
|
||||||
if reuse_cache or not self.training:
|
if reuse_cache or not self.training:
|
||||||
self._cache = None
|
self._cache = None
|
||||||
|
|
||||||
return loss, stats
|
return (
|
||||||
|
disc_stft_real_adv_loss,
|
||||||
|
disc_stft_fake_adv_loss,
|
||||||
|
disc_period_real_adv_loss,
|
||||||
|
disc_period_fake_adv_loss,
|
||||||
|
disc_scale_real_adv_loss,
|
||||||
|
disc_scale_fake_adv_loss,
|
||||||
|
stats,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
speech: torch.Tensor,
|
speech: torch.Tensor,
|
||||||
speech_lengths: torch.Tensor,
|
speech_lengths: torch.Tensor,
|
||||||
global_step: int,
|
|
||||||
return_sample: bool,
|
return_sample: bool,
|
||||||
forward_generator: bool,
|
forward_generator: bool,
|
||||||
):
|
):
|
||||||
@ -244,14 +302,12 @@ class Encodec(nn.Module):
|
|||||||
return self._forward_generator(
|
return self._forward_generator(
|
||||||
speech=speech,
|
speech=speech,
|
||||||
speech_lengths=speech_lengths,
|
speech_lengths=speech_lengths,
|
||||||
global_step=global_step,
|
|
||||||
return_sample=return_sample,
|
return_sample=return_sample,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._forward_discriminator(
|
return self._forward_discriminator(
|
||||||
speech=speech,
|
speech=speech,
|
||||||
speech_lengths=speech_lengths,
|
speech_lengths=speech_lengths,
|
||||||
global_step=global_step,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def encode(self, x, target_bw=None, st=None):
|
def encode(self, x, target_bw=None, st=None):
|
||||||
|
@ -71,7 +71,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--target-bw",
|
"--target-bw",
|
||||||
type=float,
|
type=float,
|
||||||
default=7.5,
|
default=24,
|
||||||
help="The target bandwidth for the generator",
|
help="The target bandwidth for the generator",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,8 +1,310 @@
|
|||||||
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from lhotse.features.kaldi import Wav2LogFilterBank
|
||||||
from torchaudio.transforms import MelSpectrogram
|
from torchaudio.transforms import MelSpectrogram
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorAdversarialLoss(torch.nn.Module):
|
||||||
|
"""Generator adversarial loss module."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
average_by_discriminators: bool = True,
|
||||||
|
loss_type: str = "hinge",
|
||||||
|
):
|
||||||
|
"""Initialize GeneratorAversarialLoss module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
average_by_discriminators (bool): Whether to average the loss by
|
||||||
|
the number of discriminators.
|
||||||
|
loss_type (str): Loss type, "mse" or "hinge".
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.average_by_discriminators = average_by_discriminators
|
||||||
|
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
|
||||||
|
if loss_type == "mse":
|
||||||
|
self.criterion = self._mse_loss
|
||||||
|
else:
|
||||||
|
self.criterion = self._hinge_loss
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Calcualate generator adversarial loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
|
||||||
|
outputs, list of discriminator outputs, or list of list of discriminator
|
||||||
|
outputs..
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Generator adversarial loss value.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if isinstance(outputs, (tuple, list)):
|
||||||
|
adv_loss = 0.0
|
||||||
|
for i, outputs_ in enumerate(outputs):
|
||||||
|
if isinstance(outputs_, (tuple, list)):
|
||||||
|
# NOTE(kan-bayashi): case including feature maps
|
||||||
|
outputs_ = outputs_[-1]
|
||||||
|
adv_loss += self.criterion(outputs_)
|
||||||
|
if self.average_by_discriminators:
|
||||||
|
adv_loss /= i + 1
|
||||||
|
else:
|
||||||
|
adv_loss = self.criterion(outputs)
|
||||||
|
|
||||||
|
return adv_loss
|
||||||
|
|
||||||
|
def _mse_loss(self, x):
|
||||||
|
return F.mse_loss(x, x.new_ones(x.size()))
|
||||||
|
|
||||||
|
def _hinge_loss(self, x):
|
||||||
|
return F.relu(1 - x).mean()
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorAdversarialLoss(torch.nn.Module):
|
||||||
|
"""Discriminator adversarial loss module."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
average_by_discriminators: bool = True,
|
||||||
|
loss_type: str = "hinge",
|
||||||
|
):
|
||||||
|
"""Initialize DiscriminatorAversarialLoss module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
average_by_discriminators (bool): Whether to average the loss by
|
||||||
|
the number of discriminators.
|
||||||
|
loss_type (str): Loss type, "mse" or "hinge".
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.average_by_discriminators = average_by_discriminators
|
||||||
|
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
|
||||||
|
if loss_type == "mse":
|
||||||
|
self.fake_criterion = self._mse_fake_loss
|
||||||
|
self.real_criterion = self._mse_real_loss
|
||||||
|
else:
|
||||||
|
self.fake_criterion = self._hinge_fake_loss
|
||||||
|
self.real_criterion = self._hinge_real_loss
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
|
||||||
|
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Calcualate discriminator adversarial loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
|
||||||
|
outputs, list of discriminator outputs, or list of list of discriminator
|
||||||
|
outputs calculated from generator.
|
||||||
|
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
|
||||||
|
outputs, list of discriminator outputs, or list of list of discriminator
|
||||||
|
outputs calculated from groundtruth.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Discriminator real loss value.
|
||||||
|
Tensor: Discriminator fake loss value.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if isinstance(outputs, (tuple, list)):
|
||||||
|
real_loss = 0.0
|
||||||
|
fake_loss = 0.0
|
||||||
|
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
|
||||||
|
if isinstance(outputs_hat_, (tuple, list)):
|
||||||
|
# NOTE(kan-bayashi): case including feature maps
|
||||||
|
outputs_hat_ = outputs_hat_[-1]
|
||||||
|
outputs_ = outputs_[-1]
|
||||||
|
real_loss += self.real_criterion(outputs_)
|
||||||
|
fake_loss += self.fake_criterion(outputs_hat_)
|
||||||
|
if self.average_by_discriminators:
|
||||||
|
fake_loss /= i + 1
|
||||||
|
real_loss /= i + 1
|
||||||
|
else:
|
||||||
|
real_loss = self.real_criterion(outputs)
|
||||||
|
fake_loss = self.fake_criterion(outputs_hat)
|
||||||
|
|
||||||
|
return real_loss, fake_loss
|
||||||
|
|
||||||
|
def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.mse_loss(x, x.new_ones(x.size()))
|
||||||
|
|
||||||
|
def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.mse_loss(x, x.new_zeros(x.size()))
|
||||||
|
|
||||||
|
def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.relu(x.new_ones(x.size()) - x).mean()
|
||||||
|
|
||||||
|
def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.relu(x.new_ones(x.size()) + x).mean()
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMatchLoss(torch.nn.Module):
|
||||||
|
"""Feature matching loss module."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
average_by_layers: bool = True,
|
||||||
|
average_by_discriminators: bool = True,
|
||||||
|
include_final_outputs: bool = False,
|
||||||
|
):
|
||||||
|
"""Initialize FeatureMatchLoss module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
average_by_layers (bool): Whether to average the loss by the number
|
||||||
|
of layers.
|
||||||
|
average_by_discriminators (bool): Whether to average the loss by
|
||||||
|
the number of discriminators.
|
||||||
|
include_final_outputs (bool): Whether to include the final output of
|
||||||
|
each discriminator for loss calculation.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.average_by_layers = average_by_layers
|
||||||
|
self.average_by_discriminators = average_by_discriminators
|
||||||
|
self.include_final_outputs = include_final_outputs
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
|
||||||
|
feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Calculate feature matching loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
|
||||||
|
discriminator outputs or list of discriminator outputs calcuated
|
||||||
|
from generator's outputs.
|
||||||
|
feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
|
||||||
|
discriminator outputs or list of discriminator outputs calcuated
|
||||||
|
from groundtruth..
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Feature matching loss value.
|
||||||
|
|
||||||
|
"""
|
||||||
|
feat_match_loss = 0.0
|
||||||
|
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
|
||||||
|
feat_match_loss_ = 0.0
|
||||||
|
if not self.include_final_outputs:
|
||||||
|
feats_hat_ = feats_hat_[:-1]
|
||||||
|
feats_ = feats_[:-1]
|
||||||
|
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
|
||||||
|
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
|
||||||
|
if self.average_by_layers:
|
||||||
|
feat_match_loss_ /= j + 1
|
||||||
|
feat_match_loss += feat_match_loss_
|
||||||
|
if self.average_by_discriminators:
|
||||||
|
feat_match_loss /= i + 1
|
||||||
|
|
||||||
|
return feat_match_loss
|
||||||
|
|
||||||
|
|
||||||
|
class MelSpectrogramReconstructionLoss(torch.nn.Module):
|
||||||
|
"""Mel Spec Reconstruction loss."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sampling_rate: int = 22050,
|
||||||
|
n_mels: int = 64,
|
||||||
|
use_fft_mag: bool = True,
|
||||||
|
return_mel: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.wav_to_specs = []
|
||||||
|
for i in range(5, 12):
|
||||||
|
s = 2**i
|
||||||
|
# self.wav_to_specs.append(
|
||||||
|
# Wav2LogFilterBank(
|
||||||
|
# sampling_rate=sampling_rate,
|
||||||
|
# frame_length=s,
|
||||||
|
# frame_shift=s // 4,
|
||||||
|
# use_fft_mag=use_fft_mag,
|
||||||
|
# num_filters=n_mels,
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
self.wav_to_specs.append(
|
||||||
|
MelSpectrogram(
|
||||||
|
sample_rate=sampling_rate,
|
||||||
|
n_fft=max(s, 512),
|
||||||
|
win_length=s,
|
||||||
|
hop_length=s // 4,
|
||||||
|
n_mels=n_mels,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.return_mel = return_mel
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x_hat: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
|
||||||
|
"""Calculate Mel-spectrogram loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_hat (Tensor): Generated waveform tensor (B, 1, T).
|
||||||
|
x (Tensor): Groundtruth waveform tensor (B, 1, T).
|
||||||
|
spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor
|
||||||
|
(B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth
|
||||||
|
waveform.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Mel-spectrogram loss value.
|
||||||
|
|
||||||
|
"""
|
||||||
|
mel_loss = 0.0
|
||||||
|
|
||||||
|
for i, wav_to_spec in enumerate(self.wav_to_specs):
|
||||||
|
s = 2 ** (i + 5)
|
||||||
|
wav_to_spec.to(x.device)
|
||||||
|
|
||||||
|
mel_hat = wav_to_spec(x_hat.squeeze(1))
|
||||||
|
mel = wav_to_spec(x.squeeze(1))
|
||||||
|
|
||||||
|
alpha = (s / 2) ** 0.5
|
||||||
|
mel_loss += F.l1_loss(mel_hat, mel) + alpha * F.mse_loss(mel_hat, mel)
|
||||||
|
|
||||||
|
# mel_hat = self.wav_to_spec(x_hat.squeeze(1))
|
||||||
|
# mel = self.wav_to_spec(x.squeeze(1))
|
||||||
|
# mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel)
|
||||||
|
|
||||||
|
if self.return_mel:
|
||||||
|
return mel_loss, (mel_hat, mel)
|
||||||
|
|
||||||
|
return mel_loss
|
||||||
|
|
||||||
|
|
||||||
|
class WavReconstructionLoss(torch.nn.Module):
|
||||||
|
"""Wav Reconstruction loss."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x_hat: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Calculate wav loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_hat (Tensor): Generated waveform tensor (B, 1, T).
|
||||||
|
x (Tensor): Groundtruth waveform tensor (B, 1, T).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Wav loss value.
|
||||||
|
|
||||||
|
"""
|
||||||
|
wav_loss = F.mse_loss(x, x_hat)
|
||||||
|
|
||||||
|
return wav_loss
|
||||||
|
|
||||||
|
|
||||||
def adversarial_g_loss(y_disc_gen):
|
def adversarial_g_loss(y_disc_gen):
|
||||||
"""Hinge loss"""
|
"""Hinge loss"""
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
@ -63,88 +365,12 @@ def reconstruction_loss(x, x_hat, args, eps=1e-7):
|
|||||||
return L
|
return L
|
||||||
|
|
||||||
|
|
||||||
def criterion_d(
|
|
||||||
y_disc_r,
|
|
||||||
y_disc_gen,
|
|
||||||
fmap_r_det,
|
|
||||||
fmap_gen_det,
|
|
||||||
y_df_hat_r,
|
|
||||||
y_df_hat_g,
|
|
||||||
fmap_f_r,
|
|
||||||
fmap_f_g,
|
|
||||||
y_ds_hat_r,
|
|
||||||
y_ds_hat_g,
|
|
||||||
fmap_s_r,
|
|
||||||
fmap_s_g,
|
|
||||||
):
|
|
||||||
"""Hinge Loss"""
|
|
||||||
loss = 0.0
|
|
||||||
loss1 = 0.0
|
|
||||||
loss2 = 0.0
|
|
||||||
loss3 = 0.0
|
|
||||||
for i in range(len(y_disc_r)):
|
|
||||||
loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[i]).mean()
|
|
||||||
for i in range(len(y_df_hat_r)):
|
|
||||||
loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[i]).mean()
|
|
||||||
for i in range(len(y_ds_hat_r)):
|
|
||||||
loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[i]).mean()
|
|
||||||
|
|
||||||
loss = (
|
|
||||||
loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 / len(y_ds_hat_r)
|
|
||||||
) / 3.0
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
def criterion_g(
|
|
||||||
commit_loss,
|
|
||||||
x,
|
|
||||||
G_x,
|
|
||||||
fmap_r,
|
|
||||||
fmap_gen,
|
|
||||||
y_disc_r,
|
|
||||||
y_disc_gen,
|
|
||||||
y_df_hat_r,
|
|
||||||
y_df_hat_g,
|
|
||||||
fmap_f_r,
|
|
||||||
fmap_f_g,
|
|
||||||
y_ds_hat_r,
|
|
||||||
y_ds_hat_g,
|
|
||||||
fmap_s_r,
|
|
||||||
fmap_s_g,
|
|
||||||
args,
|
|
||||||
):
|
|
||||||
adv_g_loss = adversarial_g_loss(y_disc_gen)
|
|
||||||
feat_loss = (
|
|
||||||
feature_loss(fmap_r, fmap_gen)
|
|
||||||
+ sim_loss(y_disc_r, y_disc_gen)
|
|
||||||
+ feature_loss(fmap_f_r, fmap_f_g)
|
|
||||||
+ sim_loss(y_df_hat_r, y_df_hat_g)
|
|
||||||
+ feature_loss(fmap_s_r, fmap_s_g)
|
|
||||||
+ sim_loss(y_ds_hat_r, y_ds_hat_g)
|
|
||||||
) / 3.0
|
|
||||||
rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args)
|
|
||||||
total_loss = (
|
|
||||||
args.lambda_com * commit_loss
|
|
||||||
+ args.lambda_adv * adv_g_loss
|
|
||||||
+ args.lambda_feat * feat_loss
|
|
||||||
+ args.lambda_rec * rec_loss
|
|
||||||
)
|
|
||||||
return total_loss, adv_g_loss, feat_loss, rec_loss
|
|
||||||
|
|
||||||
|
|
||||||
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
||||||
if global_step < threshold:
|
if global_step < threshold:
|
||||||
weight = value
|
weight = value
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
|
||||||
def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
|
|
||||||
if global_step % 3 == 0:
|
|
||||||
weight = value
|
|
||||||
return weight
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
|
def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
|
||||||
if last_layer is not None:
|
if last_layer is not None:
|
||||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||||
@ -166,7 +392,6 @@ def loss_g(
|
|||||||
fmap_hat,
|
fmap_hat,
|
||||||
y,
|
y,
|
||||||
y_hat,
|
y_hat,
|
||||||
global_step,
|
|
||||||
y_df,
|
y_df,
|
||||||
y_df_hat,
|
y_df_hat,
|
||||||
y_ds,
|
y_ds,
|
||||||
@ -215,9 +440,10 @@ def loss_g(
|
|||||||
feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
|
feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
|
||||||
d_weight = torch.tensor(1.0)
|
d_weight = torch.tensor(1.0)
|
||||||
|
|
||||||
disc_factor = adopt_weight(
|
# disc_factor = adopt_weight(
|
||||||
args.lambda_adv, global_step, threshold=args.discriminator_iter_start
|
# args.lambda_adv, global_step, threshold=args.discriminator_iter_start
|
||||||
)
|
# )
|
||||||
|
disc_factor = 1
|
||||||
if disc_factor == 0.0:
|
if disc_factor == 0.0:
|
||||||
fm_loss_wt = 0
|
fm_loss_wt = 0
|
||||||
else:
|
else:
|
||||||
@ -232,37 +458,9 @@ def loss_g(
|
|||||||
return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
|
return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
|
||||||
|
|
||||||
|
|
||||||
def loss_dis(
|
if __name__ == "__main__":
|
||||||
y,
|
la = FeatureMatchLoss(average_by_layers=False, average_by_discriminators=False)
|
||||||
y_hat,
|
aa = [torch.rand(192, 192) for _ in range(3)]
|
||||||
fmap,
|
bb = [torch.rand(192, 192) for _ in range(3)]
|
||||||
fmap_hat,
|
print(la(bb, aa))
|
||||||
y_df,
|
print(feature_loss(aa, bb))
|
||||||
y_df_hat,
|
|
||||||
fmap_f,
|
|
||||||
fmap_f_hat,
|
|
||||||
y_ds,
|
|
||||||
y_ds_hat,
|
|
||||||
fmap_s,
|
|
||||||
fmap_s_hat,
|
|
||||||
global_step,
|
|
||||||
args,
|
|
||||||
):
|
|
||||||
disc_factor = adopt_weight(
|
|
||||||
args.lambda_adv, global_step, threshold=args.discriminator_iter_start
|
|
||||||
)
|
|
||||||
d_loss = disc_factor * criterion_d(
|
|
||||||
y,
|
|
||||||
y_hat,
|
|
||||||
fmap,
|
|
||||||
fmap_hat,
|
|
||||||
y_df,
|
|
||||||
y_df_hat,
|
|
||||||
fmap_f,
|
|
||||||
fmap_f_hat,
|
|
||||||
y_ds,
|
|
||||||
y_ds_hat,
|
|
||||||
fmap_s,
|
|
||||||
fmap_s_hat,
|
|
||||||
)
|
|
||||||
return d_loss
|
|
||||||
|
@ -15,12 +15,13 @@ from codec_datamodule import LibriTTSCodecDataModule
|
|||||||
from encodec import Encodec
|
from encodec import Encodec
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
|
from loss import adopt_weight
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.cuda.amp import GradScaler, autocast
|
from torch.cuda.amp import GradScaler, autocast
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from utils import MetricsTracker, plot_feature, save_checkpoint
|
from utils import MetricsTracker, save_checkpoint
|
||||||
|
|
||||||
from icefall import diagnostics
|
from icefall import diagnostics
|
||||||
from icefall.checkpoint import load_checkpoint
|
from icefall.checkpoint import load_checkpoint
|
||||||
@ -250,11 +251,26 @@ def get_model(params: AttributeDict) -> nn.Module:
|
|||||||
from modules.seanet import SEANetDecoder, SEANetEncoder
|
from modules.seanet import SEANetDecoder, SEANetEncoder
|
||||||
from quantization import ResidualVectorQuantizer
|
from quantization import ResidualVectorQuantizer
|
||||||
|
|
||||||
|
# generator_params = {
|
||||||
|
# "generator_n_filters": 32,
|
||||||
|
# "dimension": 512,
|
||||||
|
# "ratios": [2, 2, 2, 4],
|
||||||
|
# "target_bandwidths": [7.5, 15],
|
||||||
|
# "bins": 1024,
|
||||||
|
# }
|
||||||
|
# discriminator_params = {
|
||||||
|
# "stft_discriminator_n_filters": 32,
|
||||||
|
# "discriminator_iter_start": 500,
|
||||||
|
# }
|
||||||
|
# inference_params = {
|
||||||
|
# "target_bw": 7.5,
|
||||||
|
# }
|
||||||
|
|
||||||
generator_params = {
|
generator_params = {
|
||||||
"generator_n_filters": 32,
|
"generator_n_filters": 32,
|
||||||
"dimension": 512,
|
"dimension": 512,
|
||||||
"ratios": [2, 2, 2, 4],
|
"ratios": [8, 5, 4, 2],
|
||||||
"target_bandwidths": [7.5, 15],
|
"target_bandwidths": [1.5, 3, 6, 12, 24],
|
||||||
"bins": 1024,
|
"bins": 1024,
|
||||||
}
|
}
|
||||||
discriminator_params = {
|
discriminator_params = {
|
||||||
@ -262,7 +278,7 @@ def get_model(params: AttributeDict) -> nn.Module:
|
|||||||
"discriminator_iter_start": 500,
|
"discriminator_iter_start": 500,
|
||||||
}
|
}
|
||||||
inference_params = {
|
inference_params = {
|
||||||
"target_bw": 7.5,
|
"target_bw": 12,
|
||||||
}
|
}
|
||||||
|
|
||||||
params.update(generator_params)
|
params.update(generator_params)
|
||||||
@ -419,36 +435,93 @@ def train_one_epoch(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with autocast(enabled=params.use_fp16):
|
with autocast(enabled=params.use_fp16):
|
||||||
|
d_weight = adopt_weight(
|
||||||
|
params.lambda_adv,
|
||||||
|
params.batch_idx_train,
|
||||||
|
threshold=params.discriminator_iter_start,
|
||||||
|
)
|
||||||
# forward discriminator
|
# forward discriminator
|
||||||
loss_d, stats_d = model(
|
(
|
||||||
|
disc_stft_real_adv_loss,
|
||||||
|
disc_stft_fake_adv_loss,
|
||||||
|
disc_period_real_adv_loss,
|
||||||
|
disc_period_fake_adv_loss,
|
||||||
|
disc_scale_real_adv_loss,
|
||||||
|
disc_scale_fake_adv_loss,
|
||||||
|
stats_d,
|
||||||
|
) = model(
|
||||||
speech=audio,
|
speech=audio,
|
||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
global_step=params.batch_idx_train,
|
|
||||||
return_sample=False,
|
return_sample=False,
|
||||||
forward_generator=False,
|
forward_generator=False,
|
||||||
)
|
)
|
||||||
|
disc_loss = (
|
||||||
|
(
|
||||||
|
disc_stft_real_adv_loss
|
||||||
|
+ disc_stft_fake_adv_loss
|
||||||
|
+ disc_period_real_adv_loss
|
||||||
|
+ disc_period_fake_adv_loss
|
||||||
|
+ disc_scale_real_adv_loss
|
||||||
|
+ disc_scale_fake_adv_loss
|
||||||
|
)
|
||||||
|
* d_weight
|
||||||
|
/ 3
|
||||||
|
)
|
||||||
for k, v in stats_d.items():
|
for k, v in stats_d.items():
|
||||||
loss_info[k] = v * batch_size
|
loss_info[k] = v * batch_size
|
||||||
# update discriminator
|
# update discriminator
|
||||||
optimizer_d.zero_grad()
|
optimizer_d.zero_grad()
|
||||||
scaler.scale(loss_d).backward()
|
scaler.scale(disc_loss).backward()
|
||||||
scaler.step(optimizer_d)
|
scaler.step(optimizer_d)
|
||||||
|
|
||||||
with autocast(enabled=params.use_fp16):
|
with autocast(enabled=params.use_fp16):
|
||||||
|
g_weight = adopt_weight(
|
||||||
|
params.lambda_adv,
|
||||||
|
params.batch_idx_train,
|
||||||
|
threshold=params.discriminator_iter_start,
|
||||||
|
)
|
||||||
# forward generator
|
# forward generator
|
||||||
loss_g, stats_g = model(
|
(
|
||||||
|
commit_loss,
|
||||||
|
gen_stft_adv_loss,
|
||||||
|
gen_period_adv_loss,
|
||||||
|
gen_scale_adv_loss,
|
||||||
|
feature_stft_loss,
|
||||||
|
feature_period_loss,
|
||||||
|
feature_scale_loss,
|
||||||
|
wav_reconstruction_loss,
|
||||||
|
mel_reconstruction_loss,
|
||||||
|
stats_g,
|
||||||
|
) = model(
|
||||||
speech=audio,
|
speech=audio,
|
||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
global_step=params.batch_idx_train,
|
|
||||||
forward_generator=True,
|
forward_generator=True,
|
||||||
return_sample=params.batch_idx_train % params.log_interval == 0,
|
return_sample=params.batch_idx_train % params.log_interval == 0,
|
||||||
)
|
)
|
||||||
|
gen_adv_loss = (
|
||||||
|
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
|
||||||
|
* g_weight
|
||||||
|
/ 3
|
||||||
|
)
|
||||||
|
feature_loss = (
|
||||||
|
feature_stft_loss + feature_period_loss + feature_scale_loss
|
||||||
|
) / 3
|
||||||
|
reconstruction_loss = (
|
||||||
|
params.lambda_wav * wav_reconstruction_loss
|
||||||
|
+ mel_reconstruction_loss
|
||||||
|
)
|
||||||
|
gen_loss = (
|
||||||
|
gen_adv_loss
|
||||||
|
+ params.lambda_rec * reconstruction_loss
|
||||||
|
+ (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
|
||||||
|
+ params.lambda_com * commit_loss
|
||||||
|
)
|
||||||
for k, v in stats_g.items():
|
for k, v in stats_g.items():
|
||||||
if "returned_sample" not in k:
|
if "returned_sample" not in k:
|
||||||
loss_info[k] = v * batch_size
|
loss_info[k] = v * batch_size
|
||||||
# update generator
|
# update generator
|
||||||
optimizer_g.zero_grad()
|
optimizer_g.zero_grad()
|
||||||
scaler.scale(loss_g).backward()
|
scaler.scale(gen_loss).backward()
|
||||||
scaler.step(optimizer_g)
|
scaler.step(optimizer_g)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
|
|
||||||
@ -619,27 +692,84 @@ def compute_validation_loss(
|
|||||||
loss_info = MetricsTracker()
|
loss_info = MetricsTracker()
|
||||||
loss_info["samples"] = batch_size
|
loss_info["samples"] = batch_size
|
||||||
|
|
||||||
|
d_weight = adopt_weight(
|
||||||
|
params.lambda_adv,
|
||||||
|
params.batch_idx_train,
|
||||||
|
threshold=params.discriminator_iter_start,
|
||||||
|
)
|
||||||
|
|
||||||
# forward discriminator
|
# forward discriminator
|
||||||
loss_d, stats_d = model(
|
(
|
||||||
|
disc_stft_real_adv_loss,
|
||||||
|
disc_stft_fake_adv_loss,
|
||||||
|
disc_period_real_adv_loss,
|
||||||
|
disc_period_fake_adv_loss,
|
||||||
|
disc_scale_real_adv_loss,
|
||||||
|
disc_scale_fake_adv_loss,
|
||||||
|
stats_d,
|
||||||
|
) = model(
|
||||||
speech=audio,
|
speech=audio,
|
||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
global_step=params.batch_idx_train,
|
|
||||||
return_sample=False,
|
return_sample=False,
|
||||||
forward_generator=False,
|
forward_generator=False,
|
||||||
)
|
)
|
||||||
assert loss_d.requires_grad is False
|
disc_loss = (
|
||||||
|
(
|
||||||
|
disc_stft_real_adv_loss
|
||||||
|
+ disc_stft_fake_adv_loss
|
||||||
|
+ disc_period_real_adv_loss
|
||||||
|
+ disc_period_fake_adv_loss
|
||||||
|
+ disc_scale_real_adv_loss
|
||||||
|
+ disc_scale_fake_adv_loss
|
||||||
|
)
|
||||||
|
* d_weight
|
||||||
|
/ 3
|
||||||
|
)
|
||||||
|
assert disc_loss.requires_grad is False
|
||||||
for k, v in stats_d.items():
|
for k, v in stats_d.items():
|
||||||
loss_info[k] = v * batch_size
|
loss_info[k] = v * batch_size
|
||||||
|
|
||||||
|
g_weight = adopt_weight(
|
||||||
|
params.lambda_adv,
|
||||||
|
params.batch_idx_train,
|
||||||
|
threshold=params.discriminator_iter_start,
|
||||||
|
)
|
||||||
# forward generator
|
# forward generator
|
||||||
loss_g, stats_g = model(
|
(
|
||||||
|
commit_loss,
|
||||||
|
gen_stft_adv_loss,
|
||||||
|
gen_period_adv_loss,
|
||||||
|
gen_scale_adv_loss,
|
||||||
|
feature_stft_loss,
|
||||||
|
feature_period_loss,
|
||||||
|
feature_scale_loss,
|
||||||
|
wav_reconstruction_loss,
|
||||||
|
mel_reconstruction_loss,
|
||||||
|
stats_g,
|
||||||
|
) = model(
|
||||||
speech=audio,
|
speech=audio,
|
||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
global_step=params.batch_idx_train,
|
|
||||||
forward_generator=True,
|
forward_generator=True,
|
||||||
return_sample=False,
|
return_sample=False,
|
||||||
)
|
)
|
||||||
assert loss_g.requires_grad is False
|
gen_adv_loss = (
|
||||||
|
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
|
||||||
|
* g_weight
|
||||||
|
/ 3
|
||||||
|
)
|
||||||
|
feature_loss = (
|
||||||
|
feature_stft_loss + feature_period_loss + feature_scale_loss
|
||||||
|
) / 3
|
||||||
|
reconstruction_loss = (
|
||||||
|
params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss
|
||||||
|
)
|
||||||
|
gen_loss = (
|
||||||
|
gen_adv_loss
|
||||||
|
+ params.lambda_rec * reconstruction_loss
|
||||||
|
+ (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
|
||||||
|
+ params.lambda_com * commit_loss
|
||||||
|
)
|
||||||
|
assert gen_loss.requires_grad is False
|
||||||
for k, v in stats_g.items():
|
for k, v in stats_g.items():
|
||||||
if "returned_sample" not in k:
|
if "returned_sample" not in k:
|
||||||
loss_info[k] = v * batch_size
|
loss_info[k] = v * batch_size
|
||||||
@ -691,24 +821,74 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
try:
|
try:
|
||||||
# for discriminator
|
# for discriminator
|
||||||
with autocast(enabled=params.use_fp16):
|
with autocast(enabled=params.use_fp16):
|
||||||
loss_d, stats_d = model(
|
(
|
||||||
|
disc_stft_real_adv_loss,
|
||||||
|
disc_stft_fake_adv_loss,
|
||||||
|
disc_period_real_adv_loss,
|
||||||
|
disc_period_fake_adv_loss,
|
||||||
|
disc_scale_real_adv_loss,
|
||||||
|
disc_scale_fake_adv_loss,
|
||||||
|
stats_d,
|
||||||
|
) = model(
|
||||||
speech=audio,
|
speech=audio,
|
||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
global_step=params.batch_idx_train,
|
|
||||||
return_sample=False,
|
return_sample=False,
|
||||||
forward_generator=False,
|
forward_generator=False,
|
||||||
)
|
)
|
||||||
|
loss_d = (
|
||||||
|
(
|
||||||
|
disc_stft_real_adv_loss
|
||||||
|
+ disc_stft_fake_adv_loss
|
||||||
|
+ disc_period_real_adv_loss
|
||||||
|
+ disc_period_fake_adv_loss
|
||||||
|
+ disc_scale_real_adv_loss
|
||||||
|
+ disc_scale_fake_adv_loss
|
||||||
|
)
|
||||||
|
* adopt_weight(
|
||||||
|
params.lambda_adv,
|
||||||
|
params.batch_idx_train,
|
||||||
|
threshold=params.discriminator_iter_start,
|
||||||
|
)
|
||||||
|
/ 3
|
||||||
|
)
|
||||||
optimizer_d.zero_grad()
|
optimizer_d.zero_grad()
|
||||||
loss_d.backward()
|
loss_d.backward()
|
||||||
# for generator
|
# for generator
|
||||||
with autocast(enabled=params.use_fp16):
|
with autocast(enabled=params.use_fp16):
|
||||||
loss_g, stats_g = model(
|
(
|
||||||
|
commit_loss,
|
||||||
|
gen_stft_adv_loss,
|
||||||
|
gen_period_adv_loss,
|
||||||
|
gen_scale_adv_loss,
|
||||||
|
feature_stft_loss,
|
||||||
|
feature_period_loss,
|
||||||
|
feature_scale_loss,
|
||||||
|
wav_reconstruction_loss,
|
||||||
|
mel_reconstruction_loss,
|
||||||
|
stats_g,
|
||||||
|
) = model(
|
||||||
speech=audio,
|
speech=audio,
|
||||||
speech_lengths=audio_lens,
|
speech_lengths=audio_lens,
|
||||||
forward_generator=True,
|
forward_generator=True,
|
||||||
global_step=params.batch_idx_train,
|
|
||||||
return_sample=False,
|
return_sample=False,
|
||||||
)
|
)
|
||||||
|
loss_g = (
|
||||||
|
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
|
||||||
|
* adopt_weight(
|
||||||
|
params.lambda_adv,
|
||||||
|
params.batch_idx_train,
|
||||||
|
threshold=params.discriminator_iter_start,
|
||||||
|
)
|
||||||
|
/ 3
|
||||||
|
+ params.lambda_rec
|
||||||
|
* (
|
||||||
|
params.lambda_wav * wav_reconstruction_loss
|
||||||
|
+ mel_reconstruction_loss
|
||||||
|
)
|
||||||
|
+ params.lambda_feat
|
||||||
|
* (feature_stft_loss + feature_period_loss + feature_scale_loss)
|
||||||
|
+ params.lambda_com * commit_loss
|
||||||
|
)
|
||||||
optimizer_g.zero_grad()
|
optimizer_g.zero_grad()
|
||||||
loss_g.backward()
|
loss_g.backward()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user