2024-10-07 01:03:26 +08:00

488 lines
16 KiB
Python

from typing import List, Tuple, Union
import torch
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank
from torchaudio.transforms import MelSpectrogram
class GeneratorAdversarialLoss(torch.nn.Module):
"""Generator adversarial loss module."""
def __init__(
self,
average_by_discriminators: bool = True,
loss_type: str = "hinge",
):
"""Initialize GeneratorAversarialLoss module.
Args:
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
loss_type (str): Loss type, "mse" or "hinge".
"""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.criterion = self._mse_loss
else:
self.criterion = self._hinge_loss
def forward(
self,
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
"""Calcualate generator adversarial loss.
Args:
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs..
Returns:
Tensor: Generator adversarial loss value.
"""
adv_loss = 0.0
if isinstance(outputs, (tuple, list)):
for i, outputs_ in enumerate(outputs):
if isinstance(outputs_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_ = outputs_[-1]
adv_loss += self.criterion(outputs_)
if self.average_by_discriminators:
adv_loss /= i + 1
else:
for i, outputs_ in enumerate(outputs):
adv_loss += self.criterion(outputs_)
adv_loss /= i + 1
return adv_loss
def _mse_loss(self, x):
return F.mse_loss(x, x.new_ones(x.size()))
def _hinge_loss(self, x):
return F.relu(1 - x).mean()
class DiscriminatorAdversarialLoss(torch.nn.Module):
"""Discriminator adversarial loss module."""
def __init__(
self,
average_by_discriminators: bool = True,
loss_type: str = "hinge",
):
"""Initialize DiscriminatorAversarialLoss module.
Args:
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
loss_type (str): Loss type, "mse" or "hinge".
"""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.fake_criterion = self._mse_fake_loss
self.real_criterion = self._mse_real_loss
else:
self.fake_criterion = self._hinge_fake_loss
self.real_criterion = self._hinge_real_loss
def forward(
self,
outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calcualate discriminator adversarial loss.
Args:
outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs calculated from generator.
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
outputs, list of discriminator outputs, or list of list of discriminator
outputs calculated from groundtruth.
Returns:
Tensor: Discriminator real loss value.
Tensor: Discriminator fake loss value.
"""
real_loss = 0.0
fake_loss = 0.0
if isinstance(outputs, (tuple, list)):
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
if isinstance(outputs_hat_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_hat_ = outputs_hat_[-1]
outputs_ = outputs_[-1]
real_loss += self.real_criterion(outputs_)
fake_loss += self.fake_criterion(outputs_hat_)
if self.average_by_discriminators:
fake_loss /= i + 1
real_loss /= i + 1
else:
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
real_loss += self.real_criterion(outputs_)
fake_loss += self.fake_criterion(outputs_hat_)
fake_loss /= i + 1
real_loss /= i + 1
return real_loss, fake_loss
def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_ones(x.size()))
def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.mse_loss(x, x.new_zeros(x.size()))
def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.relu(torch.ones_like(x) - x).mean()
def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
return F.relu(torch.ones_like(x) + x).mean()
class FeatureLoss(torch.nn.Module):
"""Feature loss module."""
def __init__(
self,
average_by_layers: bool = True,
average_by_discriminators: bool = True,
include_final_outputs: bool = True,
):
"""Initialize FeatureMatchLoss module.
Args:
average_by_layers (bool): Whether to average the loss by the number
of layers.
average_by_discriminators (bool): Whether to average the loss by
the number of discriminators.
include_final_outputs (bool): Whether to include the final output of
each discriminator for loss calculation.
"""
super().__init__()
self.average_by_layers = average_by_layers
self.average_by_discriminators = average_by_discriminators
self.include_final_outputs = include_final_outputs
def forward(
self,
feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
) -> torch.Tensor:
"""Calculate feature matching loss.
Args:
feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
discriminator outputs or list of discriminator outputs calcuated
from generator's outputs.
feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
discriminator outputs or list of discriminator outputs calcuated
from groundtruth..
Returns:
Tensor: Feature matching loss value.
"""
feat_match_loss = 0.0
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
feat_match_loss_ = 0.0
if not self.include_final_outputs:
feats_hat_ = feats_hat_[:-1]
feats_ = feats_[:-1]
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
feat_match_loss_ += (
F.l1_loss(feat_hat_, feat_.detach()) / (feat_.detach().abs().mean())
).mean()
if self.average_by_layers:
feat_match_loss_ /= j + 1
feat_match_loss += feat_match_loss_
if self.average_by_discriminators:
feat_match_loss /= i + 1
return feat_match_loss
class MelSpectrogramReconstructionLoss(torch.nn.Module):
"""Mel Spec Reconstruction loss."""
def __init__(
self,
sampling_rate: int = 22050,
n_mels: int = 64,
use_fft_mag: bool = True,
return_mel: bool = False,
):
super().__init__()
self.wav_to_specs = []
for i in range(5, 12):
s = 2**i
# self.wav_to_specs.append(
# Wav2LogFilterBank(
# sampling_rate=sampling_rate,
# frame_length=s,
# frame_shift=s // 4,
# use_fft_mag=use_fft_mag,
# num_filters=n_mels,
# )
# )
self.wav_to_specs.append(
MelSpectrogram(
sample_rate=sampling_rate,
n_fft=max(s, 512),
win_length=s,
hop_length=s // 4,
n_mels=n_mels,
)
)
self.return_mel = return_mel
def forward(
self,
x_hat: torch.Tensor,
x: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
"""Calculate Mel-spectrogram loss.
Args:
x_hat (Tensor): Generated waveform tensor (B, 1, T).
x (Tensor): Groundtruth waveform tensor (B, 1, T).
spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor
(B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth
waveform.
Returns:
Tensor: Mel-spectrogram loss value.
"""
mel_loss = 0.0
for i, wav_to_spec in enumerate(self.wav_to_specs):
s = 2 ** (i + 5)
wav_to_spec.to(x.device)
mel_hat = wav_to_spec(x_hat.squeeze(1))
mel = wav_to_spec(x.squeeze(1))
mel_loss += (
F.l1_loss(mel_hat, mel, reduce=True, reduction="mean")
+ (
(
(torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7))
** 2
).mean(dim=-2)
** 0.5
).mean()
)
# mel_hat = self.wav_to_spec(x_hat.squeeze(1))
# mel = self.wav_to_spec(x.squeeze(1))
# mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel)
if self.return_mel:
return mel_loss, (mel_hat, mel)
return mel_loss
class WavReconstructionLoss(torch.nn.Module):
"""Wav Reconstruction loss."""
def __init__(self):
super().__init__()
def forward(
self,
x_hat: torch.Tensor,
x: torch.Tensor,
) -> torch.Tensor:
"""Calculate wav loss.
Args:
x_hat (Tensor): Generated waveform tensor (B, 1, T).
x (Tensor): Groundtruth waveform tensor (B, 1, T).
Returns:
Tensor: Wav loss value.
"""
wav_loss = F.l1_loss(x, x_hat)
return wav_loss
def adversarial_g_loss(y_disc_gen):
"""Hinge loss"""
loss = 0.0
for i in range(len(y_disc_gen)):
stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze()
loss += stft_loss
return loss / len(y_disc_gen)
def feature_loss(fmap_r, fmap_gen):
loss = 0.0
for i in range(len(fmap_r)):
for j in range(len(fmap_r[i])):
stft_loss = (
(fmap_r[i][j] - fmap_gen[i][j]).abs() / (fmap_r[i][j].abs().mean())
).mean()
loss += stft_loss
return loss / (len(fmap_r) * len(fmap_r[0]))
def sim_loss(y_disc_r, y_disc_gen):
loss = 0.0
for i in range(len(y_disc_r)):
loss += F.mse_loss(y_disc_r[i], y_disc_gen[i])
return loss / len(y_disc_r)
def reconstruction_loss(x, x_hat, args, eps=1e-7):
# NOTE (lsx): hard-coded now
L = args.lambda_wav * F.mse_loss(x, x_hat) # wav L1 loss
# loss_sisnr = sisnr_loss(G_x, x) #
# L += 0.01*loss_sisnr
# 2^6=64 -> 2^10=1024
# NOTE (lsx): add 2^11
for i in range(6, 12):
# for i in range(5, 12): # Encodec setting
s = 2**i
melspec = MelSpectrogram(
sample_rate=args.sampling_rate,
n_fft=max(s, 512),
win_length=s,
hop_length=s // 4,
n_mels=64,
wkwargs={"device": x_hat.device},
).to(x_hat.device)
S_x = melspec(x)
S_x_hat = melspec(x_hat)
l1_loss = (S_x - S_x_hat).abs().mean()
l2_loss = (
((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean(
dim=-2
)
** 0.5
).mean()
alpha = (s / 2) ** 0.5
L += l1_loss + alpha * l2_loss
return L
def adopt_weight(weight, global_step, threshold=0, value=0.0):
if global_step < threshold:
weight = value
return weight
def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
if last_layer is not None:
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
else:
print("last_layer cannot be none")
assert 1 == 2
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 1.0, 1.0).detach()
d_weight = d_weight * args.lambda_adv
return d_weight
def loss_g(
codebook_loss,
speech,
speech_hat,
fmap,
fmap_hat,
y,
y_hat,
y_df,
y_df_hat,
y_ds,
y_ds_hat,
fmap_f,
fmap_f_hat,
fmap_s,
fmap_s_hat,
args=None,
):
"""
args:
codebook_loss: commit loss.
speech: ground-truth wav.
speech_hat: reconstructed wav.
fmap: real stft-D feature map.
fmap_hat: fake stft-D feature map.
y: real stft-D logits.
y_hat: fake stft-D logits.
global_step: global training step.
y_df: real MPD logits.
y_df_hat: fake MPD logits.
y_ds: real MSD logits.
y_ds_hat: fake MSD logits.
fmap_f: real MPD feature map.
fmap_f_hat: fake MPD feature map.
fmap_s: real MSD feature map.
fmap_s_hat: fake MSD feature map.
"""
rec_loss = reconstruction_loss(speech.contiguous(), speech_hat.contiguous(), args)
adv_g_loss = adversarial_g_loss(y_hat)
adv_mpd_loss = adversarial_g_loss(y_df_hat)
adv_msd_loss = adversarial_g_loss(y_ds_hat)
adv_loss = (
adv_g_loss + adv_mpd_loss + adv_msd_loss
) / 3.0 # NOTE(lsx): need to divide by 3?
feat_loss = feature_loss(
fmap, fmap_hat
) # + sim_loss(y_disc_r, y_disc_gen) # NOTE(lsx): need logits?
feat_loss_mpd = feature_loss(
fmap_f, fmap_f_hat
) # + sim_loss(y_df_hat_r, y_df_hat_g)
feat_loss_msd = feature_loss(
fmap_s, fmap_s_hat
) # + sim_loss(y_ds_hat_r, y_ds_hat_g)
feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
d_weight = torch.tensor(1.0)
# disc_factor = adopt_weight(
# args.lambda_adv, global_step, threshold=args.discriminator_iter_start
# )
disc_factor = 1
if disc_factor == 0.0:
fm_loss_wt = 0
else:
fm_loss_wt = args.lambda_feat
loss = (
rec_loss
+ d_weight * disc_factor * adv_loss
+ fm_loss_wt * feat_loss_tot
+ args.lambda_com * codebook_loss
)
return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
if __name__ == "__main__":
# la = FeatureLoss(average_by_layers=True, average_by_discriminators=True)
# aa = [torch.rand(192, 192) for _ in range(3)]
# bb = [torch.rand(192, 192) for _ in range(3)]
# print(la(bb, aa))
# print(feature_loss(aa, bb))
la = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")
aa = torch.Tensor([0.1, 0.2, 0.3, 0.4])
bb = torch.Tensor([0.4, 0.3, 0.2, 0.1])
print(la(aa))
print(adversarial_g_loss(aa))
print(la(bb))
print(adversarial_g_loss(bb))