mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 02:06:13 +00:00)

refactored loss functions

This commit is contained in:
parent 1e65a976d0
commit f9340cc5d7
@@ -139,7 +139,7 @@ class LibriTTSCodecDataModule:
        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            default=8,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

@@ -4,7 +4,13 @@ from typing import List

import numpy as np
import torch
from loss import loss_dis, loss_g
from loss import (
    DiscriminatorAdversarialLoss,
    FeatureMatchLoss,
    GeneratorAdversarialLoss,
    MelSpectrogramReconstructionLoss,
    WavReconstructionLoss,
)
from torch import nn
from torch.cuda.amp import autocast

@@ -47,11 +53,23 @@ class Encodec(nn.Module):
        self.cache_generator_outputs = cache_generator_outputs
        self._cache = None

        # construct loss functions
        self.generator_adversarial_loss = GeneratorAdversarialLoss(
            average_by_discriminators=True, loss_type="hinge"
        )
        self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
            average_by_discriminators=True, loss_type="hinge"
        )
        self.feature_match_loss = FeatureMatchLoss(average_by_layers=False)
        self.wav_reconstruction_loss = WavReconstructionLoss()
        self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
            sampling_rate=self.sampling_rate
        )

    def _forward_generator(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        global_step: int,
        return_sample: bool = False,
    ):
        """Perform generator forward.

@@ -59,7 +77,6 @@ class Encodec(nn.Module):
        Args:
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            global_step (int): Global step.
            return_sample (bool): Return the generator output.

        Returns:

@@ -107,33 +124,56 @@ class Encodec(nn.Module):

        # calculate losses
        with autocast(enabled=False):
            loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
                commit_loss,
                speech,
                speech_hat,
                fmap,
                fmap_hat,
                y,
                y_hat,
                global_step,
                y_p,
                y_p_hat,
                y_s,
                y_s_hat,
                fmap_p,
                fmap_p_hat,
                fmap_s,
                fmap_s_hat,
                args=self.params,
            gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat)
            gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat)
            gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat)

            feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat)
            feature_period_loss = self.feature_match_loss(
                feats=fmap_p, feats_hat=fmap_p_hat
            )
            feature_scale_loss = self.feature_match_loss(
                feats=fmap_s, feats_hat=fmap_s_hat
            )

            wav_reconstruction_loss = self.wav_reconstruction_loss(
                x=speech, x_hat=speech_hat
            )
            mel_reconstruction_loss = self.mel_reconstruction_loss(
                x=speech, x_hat=speech_hat
            )

            # loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
            #     commit_loss,
            #     speech,
            #     speech_hat,
            #     fmap,
            #     fmap_hat,
            #     y,
            #     y_hat,
            #     y_p,
            #     y_p_hat,
            #     y_s,
            #     y_s_hat,
            #     fmap_p,
            #     fmap_p_hat,
            #     fmap_s,
            #     fmap_s_hat,
            #     args=self.params,
            # )

        stats = dict(
            generator_loss=loss.item(),
            generator_reconstruction_loss=rec_loss.item(),
            generator_feature_loss=feat_loss.item(),
            generator_adv_loss=adv_loss.item(),
            # generator_loss=loss.item(),
            generator_wav_reconstruction_loss=wav_reconstruction_loss.item(),
            generator_mel_reconstruction_loss=mel_reconstruction_loss.item(),
            generator_feature_stft_loss=feature_stft_loss.item(),
            generator_feature_period_loss=feature_period_loss.item(),
            generator_feature_scale_loss=feature_scale_loss.item(),
            generator_stft_adv_loss=gen_stft_adv_loss.item(),
            generator_period_adv_loss=gen_period_adv_loss.item(),
            generator_scale_adv_loss=gen_scale_adv_loss.item(),
            generator_commit_loss=commit_loss.item(),
            d_weight=d_weight.item(),
            # d_weight=d_weight.item(),
        )

        if return_sample:

@@ -147,19 +187,28 @@ class Encodec(nn.Module):
        # reset cache
        if reuse_cache or not self.training:
            self._cache = None
        return loss, stats
        return (
            commit_loss,
            gen_stft_adv_loss,
            gen_period_adv_loss,
            gen_scale_adv_loss,
            feature_stft_loss,
            feature_period_loss,
            feature_scale_loss,
            wav_reconstruction_loss,
            mel_reconstruction_loss,
            stats,
        )

    def _forward_discriminator(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        global_step: int,
    ):
        """
        Args:
            speech (Tensor): Speech waveform tensor (B, T_wav).
            speech_lengths (Tensor): Speech length tensor (B,).
            global_step (int): Global step.

        Returns:
            * loss (Tensor): Loss scalar tensor.

@@ -206,37 +255,46 @@ class Encodec(nn.Module):
        )
        # calculate losses
        with autocast(enabled=False):
            loss = loss_dis(
                y,
                y_hat,
                fmap,
                fmap_hat,
                y_p,
                y_p_hat,
                fmap_p,
                fmap_p_hat,
                y_s,
                y_s_hat,
                fmap_s,
                fmap_s_hat,
                global_step,
                args=self.params,
            )
            (
                disc_stft_real_adv_loss,
                disc_stft_fake_adv_loss,
            ) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat)
            (
                disc_period_real_adv_loss,
                disc_period_fake_adv_loss,
            ) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat)
            (
                disc_scale_real_adv_loss,
                disc_scale_fake_adv_loss,
            ) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat)

        stats = dict(
            discriminator_loss=loss.item(),
            discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(),
            discriminator_period_real_adv_loss=disc_period_real_adv_loss.item(),
            discriminator_scale_real_adv_loss=disc_scale_real_adv_loss.item(),
            discriminator_stft_fake_adv_loss=disc_stft_fake_adv_loss.item(),
            discriminator_period_fake_adv_loss=disc_period_fake_adv_loss.item(),
            discriminator_scale_fake_adv_loss=disc_scale_fake_adv_loss.item(),
        )

        # reset cache
        if reuse_cache or not self.training:
            self._cache = None

        return loss, stats
        return (
            disc_stft_real_adv_loss,
            disc_stft_fake_adv_loss,
            disc_period_real_adv_loss,
            disc_period_fake_adv_loss,
            disc_scale_real_adv_loss,
            disc_scale_fake_adv_loss,
            stats,
        )

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        global_step: int,
        return_sample: bool,
        forward_generator: bool,
    ):

@@ -244,14 +302,12 @@ class Encodec(nn.Module):
            return self._forward_generator(
                speech=speech,
                speech_lengths=speech_lengths,
                global_step=global_step,
                return_sample=return_sample,
            )
        else:
            return self._forward_discriminator(
                speech=speech,
                speech_lengths=speech_lengths,
                global_step=global_step,
            )

    def encode(self, x, target_bw=None, st=None):

@@ -71,7 +71,7 @@ def get_parser():
    parser.add_argument(
        "--target-bw",
        type=float,
        default=7.5,
        default=24,
        help="The target bandwidth for the generator",
    )

@@ -1,8 +1,310 @@
from typing import List, Tuple, Union

import torch
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank
from torchaudio.transforms import MelSpectrogram


class GeneratorAdversarialLoss(torch.nn.Module):
    """Generator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "hinge",
    ):
        """Initialize GeneratorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.criterion = self._mse_loss
        else:
            self.criterion = self._hinge_loss

    def forward(
        self,
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """Calculate generator adversarial loss.

        Args:
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of discriminator
                outputs.

        Returns:
            Tensor: Generator adversarial loss value.

        """
        if isinstance(outputs, (tuple, list)):
            adv_loss = 0.0
            for i, outputs_ in enumerate(outputs):
                if isinstance(outputs_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_ = outputs_[-1]
                adv_loss += self.criterion(outputs_)
            if self.average_by_discriminators:
                adv_loss /= i + 1
        else:
            adv_loss = self.criterion(outputs)

        return adv_loss

    def _mse_loss(self, x):
        return F.mse_loss(x, x.new_ones(x.size()))

    def _hinge_loss(self, x):
        return F.relu(1 - x).mean()


class DiscriminatorAdversarialLoss(torch.nn.Module):
    """Discriminator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "hinge",
    ):
        """Initialize DiscriminatorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss
        else:
            self.fake_criterion = self._hinge_fake_loss
            self.real_criterion = self._hinge_real_loss

    def forward(
        self,
        outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of discriminator
                outputs calculated from generator.
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of discriminator
                outputs calculated from groundtruth.

        Returns:
            Tensor: Discriminator real loss value.
            Tensor: Discriminator fake loss value.

        """
        if isinstance(outputs, (tuple, list)):
            real_loss = 0.0
            fake_loss = 0.0
            for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
                if isinstance(outputs_hat_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_hat_ = outputs_hat_[-1]
                    outputs_ = outputs_[-1]
                real_loss += self.real_criterion(outputs_)
                fake_loss += self.fake_criterion(outputs_hat_)
            if self.average_by_discriminators:
                fake_loss /= i + 1
                real_loss /= i + 1
        else:
            real_loss = self.real_criterion(outputs)
            fake_loss = self.fake_criterion(outputs_hat)

        return real_loss, fake_loss

    def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(x, x.new_ones(x.size()))

    def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(x, x.new_zeros(x.size()))

    def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(x.new_ones(x.size()) - x).mean()

    def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(x.new_ones(x.size()) + x).mean()

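For orientation only, not part of this commit: a minimal sketch of how the two adversarial loss modules above behave with the default hinge objective, using dummy discriminator outputs; the import path "from loss import ..." is assumed to resolve to this file.

# Illustrative sketch only: dummy tensors, assumed import path.
import torch

from loss import DiscriminatorAdversarialLoss, GeneratorAdversarialLoss

gen_adv = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")
disc_adv = DiscriminatorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")

# Pretend three discriminators each emit a (B, T) score map.
scores_real = [torch.randn(2, 100) for _ in range(3)]
scores_fake = [torch.randn(2, 100) for _ in range(3)]

g_adv = gen_adv(outputs=scores_fake)  # mean over discriminators of relu(1 - D(G(x)))
d_real, d_fake = disc_adv(outputs_hat=scores_fake, outputs=scores_real)
print(g_adv.item(), d_real.item(), d_fake.item())
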
class FeatureMatchLoss(torch.nn.Module):
    """Feature matching loss module."""

    def __init__(
        self,
        average_by_layers: bool = True,
        average_by_discriminators: bool = True,
        include_final_outputs: bool = False,
    ):
        """Initialize FeatureMatchLoss module.

        Args:
            average_by_layers (bool): Whether to average the loss by the number
                of layers.
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            include_final_outputs (bool): Whether to include the final output of
                each discriminator for loss calculation.

        """
        super().__init__()
        self.average_by_layers = average_by_layers
        self.average_by_discriminators = average_by_discriminators
        self.include_final_outputs = include_final_outputs

    def forward(
        self,
        feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
        feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
    ) -> torch.Tensor:
        """Calculate feature matching loss.

        Args:
            feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from generator's outputs.
            feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from groundtruth.

        Returns:
            Tensor: Feature matching loss value.

        """
        feat_match_loss = 0.0
        for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
            feat_match_loss_ = 0.0
            if not self.include_final_outputs:
                feats_hat_ = feats_hat_[:-1]
                feats_ = feats_[:-1]
            for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
                feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
            if self.average_by_layers:
                feat_match_loss_ /= j + 1
            feat_match_loss += feat_match_loss_
        if self.average_by_discriminators:
            feat_match_loss /= i + 1

        return feat_match_loss


class MelSpectrogramReconstructionLoss(torch.nn.Module):
    """Mel Spec Reconstruction loss."""

    def __init__(
        self,
        sampling_rate: int = 22050,
        n_mels: int = 64,
        use_fft_mag: bool = True,
        return_mel: bool = False,
    ):
        super().__init__()
        self.wav_to_specs = []
        for i in range(5, 12):
            s = 2**i
            # self.wav_to_specs.append(
            #     Wav2LogFilterBank(
            #         sampling_rate=sampling_rate,
            #         frame_length=s,
            #         frame_shift=s // 4,
            #         use_fft_mag=use_fft_mag,
            #         num_filters=n_mels,
            #     )
            # )
            self.wav_to_specs.append(
                MelSpectrogram(
                    sample_rate=sampling_rate,
                    n_fft=max(s, 512),
                    win_length=s,
                    hop_length=s // 4,
                    n_mels=n_mels,
                )
            )
        self.return_mel = return_mel

    def forward(
        self,
        x_hat: torch.Tensor,
        x: torch.Tensor,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Calculate Mel-spectrogram loss.

        Args:
            x_hat (Tensor): Generated waveform tensor (B, 1, T).
            x (Tensor): Groundtruth waveform tensor (B, 1, T).
            spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor
                (B, T, n_fft // 2 + 1). If provided, use it instead of the groundtruth
                waveform.

        Returns:
            Tensor: Mel-spectrogram loss value.

        """
        mel_loss = 0.0

        for i, wav_to_spec in enumerate(self.wav_to_specs):
            s = 2 ** (i + 5)
            wav_to_spec.to(x.device)

            mel_hat = wav_to_spec(x_hat.squeeze(1))
            mel = wav_to_spec(x.squeeze(1))

            alpha = (s / 2) ** 0.5
            mel_loss += F.l1_loss(mel_hat, mel) + alpha * F.mse_loss(mel_hat, mel)

        # mel_hat = self.wav_to_spec(x_hat.squeeze(1))
        # mel = self.wav_to_spec(x.squeeze(1))
        # mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel)

        if self.return_mel:
            return mel_loss, (mel_hat, mel)

        return mel_loss

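An aside, not part of the diff: MelSpectrogramReconstructionLoss above compares mel spectrograms at seven resolutions, using window sizes s = 2**5 ... 2**11 (32 to 2048 samples) and an MSE term weighted by alpha = sqrt(s / 2). A minimal usage sketch on dummy waveforms, assuming 24 kHz audio and the same import path assumption as before:

# Illustrative sketch only: dummy waveforms, assumed 24 kHz sampling rate.
import torch

from loss import MelSpectrogramReconstructionLoss

mel_loss_fn = MelSpectrogramReconstructionLoss(sampling_rate=24000, n_mels=64)
x = torch.randn(2, 1, 24000)      # ground-truth waveform (B, 1, T)
x_hat = torch.randn(2, 1, 24000)  # generated waveform (B, 1, T)
print(mel_loss_fn(x_hat=x_hat, x=x).item())  # sum of L1 + alpha * MSE over the 7 scales
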
class WavReconstructionLoss(torch.nn.Module):
    """Wav Reconstruction loss."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        x_hat: torch.Tensor,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate wav loss.

        Args:
            x_hat (Tensor): Generated waveform tensor (B, 1, T).
            x (Tensor): Groundtruth waveform tensor (B, 1, T).

        Returns:
            Tensor: Wav loss value.

        """
        wav_loss = F.mse_loss(x, x_hat)

        return wav_loss


def adversarial_g_loss(y_disc_gen):
    """Hinge loss"""
    loss = 0.0

@@ -63,88 +365,12 @@ def reconstruction_loss(x, x_hat, args, eps=1e-7):
    return L


def criterion_d(
    y_disc_r,
    y_disc_gen,
    fmap_r_det,
    fmap_gen_det,
    y_df_hat_r,
    y_df_hat_g,
    fmap_f_r,
    fmap_f_g,
    y_ds_hat_r,
    y_ds_hat_g,
    fmap_s_r,
    fmap_s_g,
):
    """Hinge Loss"""
    loss = 0.0
    loss1 = 0.0
    loss2 = 0.0
    loss3 = 0.0
    for i in range(len(y_disc_r)):
        loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[i]).mean()
    for i in range(len(y_df_hat_r)):
        loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[i]).mean()
    for i in range(len(y_ds_hat_r)):
        loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[i]).mean()

    loss = (
        loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 / len(y_ds_hat_r)
    ) / 3.0

    return loss


def criterion_g(
    commit_loss,
    x,
    G_x,
    fmap_r,
    fmap_gen,
    y_disc_r,
    y_disc_gen,
    y_df_hat_r,
    y_df_hat_g,
    fmap_f_r,
    fmap_f_g,
    y_ds_hat_r,
    y_ds_hat_g,
    fmap_s_r,
    fmap_s_g,
    args,
):
    adv_g_loss = adversarial_g_loss(y_disc_gen)
    feat_loss = (
        feature_loss(fmap_r, fmap_gen)
        + sim_loss(y_disc_r, y_disc_gen)
        + feature_loss(fmap_f_r, fmap_f_g)
        + sim_loss(y_df_hat_r, y_df_hat_g)
        + feature_loss(fmap_s_r, fmap_s_g)
        + sim_loss(y_ds_hat_r, y_ds_hat_g)
    ) / 3.0
    rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args)
    total_loss = (
        args.lambda_com * commit_loss
        + args.lambda_adv * adv_g_loss
        + args.lambda_feat * feat_loss
        + args.lambda_rec * rec_loss
    )
    return total_loss, adv_g_loss, feat_loss, rec_loss


def adopt_weight(weight, global_step, threshold=0, value=0.0):
    if global_step < threshold:
        weight = value
    return weight

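For clarity (illustration only, not part of the commit): adopt_weight gates a weight until global_step reaches threshold, which is how the adversarial terms are disabled during warm-up in the training loop below (discriminator_iter_start is 500 in this recipe).

# Illustrative sketch only.
from loss import adopt_weight

lambda_adv = 1.0
print(adopt_weight(lambda_adv, global_step=100, threshold=500))  # 0.0 -> adversarial terms off
print(adopt_weight(lambda_adv, global_step=600, threshold=500))  # 1.0 -> adversarial terms on
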
def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
    if global_step % 3 == 0:
        weight = value
    return weight


def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
    if last_layer is not None:
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]

@@ -166,7 +392,6 @@ def loss_g(
    fmap_hat,
    y,
    y_hat,
    global_step,
    y_df,
    y_df_hat,
    y_ds,

@@ -215,9 +440,10 @@ def loss_g(
    feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
    d_weight = torch.tensor(1.0)

    disc_factor = adopt_weight(
        args.lambda_adv, global_step, threshold=args.discriminator_iter_start
    )
    # disc_factor = adopt_weight(
    #     args.lambda_adv, global_step, threshold=args.discriminator_iter_start
    # )
    disc_factor = 1
    if disc_factor == 0.0:
        fm_loss_wt = 0
    else:

@@ -232,37 +458,9 @@ def loss_g(
    return loss, rec_loss, adv_loss, feat_loss_tot, d_weight


def loss_dis(
    y,
    y_hat,
    fmap,
    fmap_hat,
    y_df,
    y_df_hat,
    fmap_f,
    fmap_f_hat,
    y_ds,
    y_ds_hat,
    fmap_s,
    fmap_s_hat,
    global_step,
    args,
):
    disc_factor = adopt_weight(
        args.lambda_adv, global_step, threshold=args.discriminator_iter_start
    )
    d_loss = disc_factor * criterion_d(
        y,
        y_hat,
        fmap,
        fmap_hat,
        y_df,
        y_df_hat,
        fmap_f,
        fmap_f_hat,
        y_ds,
        y_ds_hat,
        fmap_s,
        fmap_s_hat,
    )
    return d_loss
if __name__ == "__main__":
    la = FeatureMatchLoss(average_by_layers=False, average_by_discriminators=False)
    aa = [torch.rand(192, 192) for _ in range(3)]
    bb = [torch.rand(192, 192) for _ in range(3)]
    print(la(bb, aa))
    print(feature_loss(aa, bb))

@@ -15,12 +15,13 @@ from codec_datamodule import LibriTTSCodecDataModule
from encodec import Encodec
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from loss import adopt_weight
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from utils import MetricsTracker, plot_feature, save_checkpoint
from utils import MetricsTracker, save_checkpoint

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint

@@ -250,11 +251,26 @@ def get_model(params: AttributeDict) -> nn.Module:
    from modules.seanet import SEANetDecoder, SEANetEncoder
    from quantization import ResidualVectorQuantizer

    # generator_params = {
    #     "generator_n_filters": 32,
    #     "dimension": 512,
    #     "ratios": [2, 2, 2, 4],
    #     "target_bandwidths": [7.5, 15],
    #     "bins": 1024,
    # }
    # discriminator_params = {
    #     "stft_discriminator_n_filters": 32,
    #     "discriminator_iter_start": 500,
    # }
    # inference_params = {
    #     "target_bw": 7.5,
    # }

    generator_params = {
        "generator_n_filters": 32,
        "dimension": 512,
        "ratios": [2, 2, 2, 4],
        "target_bandwidths": [7.5, 15],
        "ratios": [8, 5, 4, 2],
        "target_bandwidths": [1.5, 3, 6, 12, 24],
        "bins": 1024,
    }
    discriminator_params = {

@@ -262,7 +278,7 @@ def get_model(params: AttributeDict) -> nn.Module:
        "discriminator_iter_start": 500,
    }
    inference_params = {
        "target_bw": 7.5,
        "target_bw": 12,
    }

    params.update(generator_params)

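A back-of-the-envelope check of the new generator configuration (an aside, assuming 24 kHz LibriTTS audio): the ratios give a total hop of 320 samples, i.e. 75 frames per second, so each 1024-entry codebook costs 0.75 kbps and the listed target bandwidths of 1.5 to 24 kbps correspond to 2 to 32 quantizers.

# Illustrative arithmetic only (assumes 24 kHz audio).
ratios = [8, 5, 4, 2]
hop = 1
for r in ratios:
    hop *= r                                 # 320 samples per frame
frame_rate = 24000 / hop                     # 75 frames per second
kbps_per_codebook = frame_rate * 10 / 1000   # 1024 bins -> 10 bits -> 0.75 kbps
print(hop, frame_rate, kbps_per_codebook)    # 320 75.0 0.75
print([round(bw / kbps_per_codebook) for bw in [1.5, 3, 6, 12, 24]])  # [2, 4, 8, 16, 32]
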
@@ -419,36 +435,93 @@ def train_one_epoch(

        try:
            with autocast(enabled=params.use_fp16):
                d_weight = adopt_weight(
                    params.lambda_adv,
                    params.batch_idx_train,
                    threshold=params.discriminator_iter_start,
                )
                # forward discriminator
                loss_d, stats_d = model(
                (
                    disc_stft_real_adv_loss,
                    disc_stft_fake_adv_loss,
                    disc_period_real_adv_loss,
                    disc_period_fake_adv_loss,
                    disc_scale_real_adv_loss,
                    disc_scale_fake_adv_loss,
                    stats_d,
                ) = model(
                    speech=audio,
                    speech_lengths=audio_lens,
                    global_step=params.batch_idx_train,
                    return_sample=False,
                    forward_generator=False,
                )
                disc_loss = (
                    (
                        disc_stft_real_adv_loss
                        + disc_stft_fake_adv_loss
                        + disc_period_real_adv_loss
                        + disc_period_fake_adv_loss
                        + disc_scale_real_adv_loss
                        + disc_scale_fake_adv_loss
                    )
                    * d_weight
                    / 3
                )
            for k, v in stats_d.items():
                loss_info[k] = v * batch_size
            # update discriminator
            optimizer_d.zero_grad()
            scaler.scale(loss_d).backward()
            scaler.scale(disc_loss).backward()
            scaler.step(optimizer_d)

            with autocast(enabled=params.use_fp16):
                g_weight = adopt_weight(
                    params.lambda_adv,
                    params.batch_idx_train,
                    threshold=params.discriminator_iter_start,
                )
                # forward generator
                loss_g, stats_g = model(
                (
                    commit_loss,
                    gen_stft_adv_loss,
                    gen_period_adv_loss,
                    gen_scale_adv_loss,
                    feature_stft_loss,
                    feature_period_loss,
                    feature_scale_loss,
                    wav_reconstruction_loss,
                    mel_reconstruction_loss,
                    stats_g,
                ) = model(
                    speech=audio,
                    speech_lengths=audio_lens,
                    global_step=params.batch_idx_train,
                    forward_generator=True,
                    return_sample=params.batch_idx_train % params.log_interval == 0,
                )
                gen_adv_loss = (
                    (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
                    * g_weight
                    / 3
                )
                feature_loss = (
                    feature_stft_loss + feature_period_loss + feature_scale_loss
                ) / 3
                reconstruction_loss = (
                    params.lambda_wav * wav_reconstruction_loss
                    + mel_reconstruction_loss
                )
                gen_loss = (
                    gen_adv_loss
                    + params.lambda_rec * reconstruction_loss
                    + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
                    + params.lambda_com * commit_loss
                )
            for k, v in stats_g.items():
                if "returned_sample" not in k:
                    loss_info[k] = v * batch_size
            # update generator
            optimizer_g.zero_grad()
            scaler.scale(loss_g).backward()
            scaler.scale(gen_loss).backward()
            scaler.step(optimizer_g)
            scaler.update()

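In summary (notation mine, summarizing the code above rather than quoting the commit), with w_d = w_g = adopt_weight(lambda_adv, step, threshold=discriminator_iter_start), the objectives assembled in train_one_epoch are

\mathcal{L}_D = \frac{w_d}{3} \sum_{k \in \{\mathrm{stft},\,\mathrm{period},\,\mathrm{scale}\}}
    \left( \mathcal{L}^{(k)}_{\mathrm{adv,real}} + \mathcal{L}^{(k)}_{\mathrm{adv,fake}} \right)

\mathcal{L}_G = \frac{w_g}{3} \sum_{k} \mathcal{L}^{(k)}_{\mathrm{adv}}
    + \lambda_{\mathrm{rec}} \left( \lambda_{\mathrm{wav}} \mathcal{L}_{\mathrm{wav}} + \mathcal{L}_{\mathrm{mel}} \right)
    + \mathbf{1}[w_g \neq 0] \, \frac{\lambda_{\mathrm{feat}}}{3} \sum_{k} \mathcal{L}^{(k)}_{\mathrm{feat}}
    + \lambda_{\mathrm{com}} \mathcal{L}_{\mathrm{commit}}
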
@@ -619,27 +692,84 @@ def compute_validation_loss(
            loss_info = MetricsTracker()
            loss_info["samples"] = batch_size

            d_weight = adopt_weight(
                params.lambda_adv,
                params.batch_idx_train,
                threshold=params.discriminator_iter_start,
            )

            # forward discriminator
            loss_d, stats_d = model(
            (
                disc_stft_real_adv_loss,
                disc_stft_fake_adv_loss,
                disc_period_real_adv_loss,
                disc_period_fake_adv_loss,
                disc_scale_real_adv_loss,
                disc_scale_fake_adv_loss,
                stats_d,
            ) = model(
                speech=audio,
                speech_lengths=audio_lens,
                global_step=params.batch_idx_train,
                return_sample=False,
                forward_generator=False,
            )
            assert loss_d.requires_grad is False
            disc_loss = (
                (
                    disc_stft_real_adv_loss
                    + disc_stft_fake_adv_loss
                    + disc_period_real_adv_loss
                    + disc_period_fake_adv_loss
                    + disc_scale_real_adv_loss
                    + disc_scale_fake_adv_loss
                )
                * d_weight
                / 3
            )
            assert disc_loss.requires_grad is False
            for k, v in stats_d.items():
                loss_info[k] = v * batch_size

            g_weight = adopt_weight(
                params.lambda_adv,
                params.batch_idx_train,
                threshold=params.discriminator_iter_start,
            )
            # forward generator
            loss_g, stats_g = model(
            (
                commit_loss,
                gen_stft_adv_loss,
                gen_period_adv_loss,
                gen_scale_adv_loss,
                feature_stft_loss,
                feature_period_loss,
                feature_scale_loss,
                wav_reconstruction_loss,
                mel_reconstruction_loss,
                stats_g,
            ) = model(
                speech=audio,
                speech_lengths=audio_lens,
                global_step=params.batch_idx_train,
                forward_generator=True,
                return_sample=False,
            )
            assert loss_g.requires_grad is False
            gen_adv_loss = (
                (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
                * g_weight
                / 3
            )
            feature_loss = (
                feature_stft_loss + feature_period_loss + feature_scale_loss
            ) / 3
            reconstruction_loss = (
                params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss
            )
            gen_loss = (
                gen_adv_loss
                + params.lambda_rec * reconstruction_loss
                + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
                + params.lambda_com * commit_loss
            )
            assert gen_loss.requires_grad is False
            for k, v in stats_g.items():
                if "returned_sample" not in k:
                    loss_info[k] = v * batch_size

@@ -691,24 +821,74 @@ def scan_pessimistic_batches_for_oom(
        try:
            # for discriminator
            with autocast(enabled=params.use_fp16):
                loss_d, stats_d = model(
                (
                    disc_stft_real_adv_loss,
                    disc_stft_fake_adv_loss,
                    disc_period_real_adv_loss,
                    disc_period_fake_adv_loss,
                    disc_scale_real_adv_loss,
                    disc_scale_fake_adv_loss,
                    stats_d,
                ) = model(
                    speech=audio,
                    speech_lengths=audio_lens,
                    global_step=params.batch_idx_train,
                    return_sample=False,
                    forward_generator=False,
                )
            loss_d = (
                (
                    disc_stft_real_adv_loss
                    + disc_stft_fake_adv_loss
                    + disc_period_real_adv_loss
                    + disc_period_fake_adv_loss
                    + disc_scale_real_adv_loss
                    + disc_scale_fake_adv_loss
                )
                * adopt_weight(
                    params.lambda_adv,
                    params.batch_idx_train,
                    threshold=params.discriminator_iter_start,
                )
                / 3
            )
            optimizer_d.zero_grad()
            loss_d.backward()
            # for generator
            with autocast(enabled=params.use_fp16):
                loss_g, stats_g = model(
                (
                    commit_loss,
                    gen_stft_adv_loss,
                    gen_period_adv_loss,
                    gen_scale_adv_loss,
                    feature_stft_loss,
                    feature_period_loss,
                    feature_scale_loss,
                    wav_reconstruction_loss,
                    mel_reconstruction_loss,
                    stats_g,
                ) = model(
                    speech=audio,
                    speech_lengths=audio_lens,
                    forward_generator=True,
                    global_step=params.batch_idx_train,
                    return_sample=False,
                )
            loss_g = (
                (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
                * adopt_weight(
                    params.lambda_adv,
                    params.batch_idx_train,
                    threshold=params.discriminator_iter_start,
                )
                / 3
                + params.lambda_rec
                * (
                    params.lambda_wav * wav_reconstruction_loss
                    + mel_reconstruction_loss
                )
                + params.lambda_feat
                * (feature_stft_loss + feature_period_loss + feature_scale_loss)
                + params.lambda_com * commit_loss
            )
            optimizer_g.zero_grad()
            loss_g.backward()
        except Exception as e: