mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 02:22:17 +00:00
322 lines
11 KiB
Python
322 lines
11 KiB
Python
# Modified from egs/ljspeech/TTS/vits/loss.py by: Zengrui JIN (Tsinghua University)
|
|
# original implementation is from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
|
|
|
|
# Copyright 2021 Tomoki Hayashi
|
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
|
|
"""Encodec-related loss modules.
|
|
|
|
This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
|
|
|
|
"""
|
|
|
|
from typing import List, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torchaudio.transforms import MelSpectrogram
|
|
|
|
|
|
class GeneratorAdversarialLoss(torch.nn.Module):
|
|
"""Generator adversarial loss module."""
|
|
|
|
def __init__(
|
|
self,
|
|
average_by_discriminators: bool = True,
|
|
loss_type: str = "hinge",
|
|
):
|
|
"""Initialize GeneratorAversarialLoss module.
|
|
|
|
Args:
|
|
average_by_discriminators (bool): Whether to average the loss by
|
|
the number of discriminators.
|
|
loss_type (str): Loss type, "mse" or "hinge".
|
|
|
|
"""
|
|
super().__init__()
|
|
self.average_by_discriminators = average_by_discriminators
|
|
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
|
|
if loss_type == "mse":
|
|
self.criterion = self._mse_loss
|
|
else:
|
|
self.criterion = self._hinge_loss
|
|
|
|
def forward(
|
|
self,
|
|
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
|
|
) -> torch.Tensor:
|
|
"""Calcualate generator adversarial loss.
|
|
|
|
Args:
|
|
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
|
|
outputs, list of discriminator outputs, or list of list of discriminator
|
|
outputs..
|
|
|
|
Returns:
|
|
Tensor: Generator adversarial loss value.
|
|
|
|
"""
|
|
adv_loss = 0.0
|
|
if isinstance(outputs, (tuple, list)):
|
|
for i, outputs_ in enumerate(outputs):
|
|
if isinstance(outputs_, (tuple, list)):
|
|
# NOTE(kan-bayashi): case including feature maps
|
|
outputs_ = outputs_[-1]
|
|
adv_loss += self.criterion(outputs_)
|
|
if self.average_by_discriminators:
|
|
adv_loss /= i + 1
|
|
else:
|
|
for i, outputs_ in enumerate(outputs):
|
|
adv_loss += self.criterion(outputs_)
|
|
adv_loss /= i + 1
|
|
return adv_loss
|
|
|
|
def _mse_loss(self, x):
|
|
return F.mse_loss(x, x.new_ones(x.size()))
|
|
|
|
def _hinge_loss(self, x):
|
|
return F.relu(1 - x).mean()
|
|
|
|
|
|
class DiscriminatorAdversarialLoss(torch.nn.Module):
|
|
"""Discriminator adversarial loss module."""
|
|
|
|
def __init__(
|
|
self,
|
|
average_by_discriminators: bool = True,
|
|
loss_type: str = "hinge",
|
|
):
|
|
"""Initialize DiscriminatorAversarialLoss module.
|
|
|
|
Args:
|
|
average_by_discriminators (bool): Whether to average the loss by
|
|
the number of discriminators.
|
|
loss_type (str): Loss type, "mse" or "hinge".
|
|
|
|
"""
|
|
super().__init__()
|
|
self.average_by_discriminators = average_by_discriminators
|
|
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
|
|
if loss_type == "mse":
|
|
self.fake_criterion = self._mse_fake_loss
|
|
self.real_criterion = self._mse_real_loss
|
|
else:
|
|
self.fake_criterion = self._hinge_fake_loss
|
|
self.real_criterion = self._hinge_real_loss
|
|
|
|
def forward(
|
|
self,
|
|
outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
|
|
outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
"""Calcualate discriminator adversarial loss.
|
|
|
|
Args:
|
|
outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
|
|
outputs, list of discriminator outputs, or list of list of discriminator
|
|
outputs calculated from generator.
|
|
outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
|
|
outputs, list of discriminator outputs, or list of list of discriminator
|
|
outputs calculated from groundtruth.
|
|
|
|
Returns:
|
|
Tensor: Discriminator real loss value.
|
|
Tensor: Discriminator fake loss value.
|
|
|
|
"""
|
|
real_loss = 0.0
|
|
fake_loss = 0.0
|
|
if isinstance(outputs, (tuple, list)):
|
|
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
|
|
if isinstance(outputs_hat_, (tuple, list)):
|
|
# NOTE(kan-bayashi): case including feature maps
|
|
outputs_hat_ = outputs_hat_[-1]
|
|
outputs_ = outputs_[-1]
|
|
real_loss += self.real_criterion(outputs_)
|
|
fake_loss += self.fake_criterion(outputs_hat_)
|
|
if self.average_by_discriminators:
|
|
fake_loss /= i + 1
|
|
real_loss /= i + 1
|
|
else:
|
|
for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
|
|
real_loss += self.real_criterion(outputs_)
|
|
fake_loss += self.fake_criterion(outputs_hat_)
|
|
fake_loss /= i + 1
|
|
real_loss /= i + 1
|
|
|
|
return real_loss, fake_loss
|
|
|
|
def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
|
|
return F.mse_loss(x, x.new_ones(x.size()))
|
|
|
|
def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
|
|
return F.mse_loss(x, x.new_zeros(x.size()))
|
|
|
|
def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
|
|
return F.relu(torch.ones_like(x) - x).mean()
|
|
|
|
def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
|
|
return F.relu(torch.ones_like(x) + x).mean()
|
|
|
|
|
|
class FeatureLoss(torch.nn.Module):
|
|
"""Feature loss module."""
|
|
|
|
def __init__(
|
|
self,
|
|
average_by_layers: bool = True,
|
|
average_by_discriminators: bool = True,
|
|
include_final_outputs: bool = True,
|
|
):
|
|
"""Initialize FeatureMatchLoss module.
|
|
|
|
Args:
|
|
average_by_layers (bool): Whether to average the loss by the number
|
|
of layers.
|
|
average_by_discriminators (bool): Whether to average the loss by
|
|
the number of discriminators.
|
|
include_final_outputs (bool): Whether to include the final output of
|
|
each discriminator for loss calculation.
|
|
|
|
"""
|
|
super().__init__()
|
|
self.average_by_layers = average_by_layers
|
|
self.average_by_discriminators = average_by_discriminators
|
|
self.include_final_outputs = include_final_outputs
|
|
|
|
def forward(
|
|
self,
|
|
feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
|
|
feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
|
|
) -> torch.Tensor:
|
|
"""Calculate feature matching loss.
|
|
|
|
Args:
|
|
feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
|
|
discriminator outputs or list of discriminator outputs calcuated
|
|
from generator's outputs.
|
|
feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
|
|
discriminator outputs or list of discriminator outputs calcuated
|
|
from groundtruth..
|
|
|
|
Returns:
|
|
Tensor: Feature matching loss value.
|
|
|
|
"""
|
|
feat_match_loss = 0.0
|
|
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
|
|
feat_match_loss_ = 0.0
|
|
if not self.include_final_outputs:
|
|
feats_hat_ = feats_hat_[:-1]
|
|
feats_ = feats_[:-1]
|
|
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
|
|
feat_match_loss_ += (
|
|
F.l1_loss(feat_hat_, feat_.detach()) / (feat_.detach().abs().mean())
|
|
).mean()
|
|
if self.average_by_layers:
|
|
feat_match_loss_ /= j + 1
|
|
feat_match_loss += feat_match_loss_
|
|
if self.average_by_discriminators:
|
|
feat_match_loss /= i + 1
|
|
|
|
return feat_match_loss
|
|
|
|
|
|
class MelSpectrogramReconstructionLoss(torch.nn.Module):
|
|
"""Mel Spec Reconstruction loss."""
|
|
|
|
def __init__(
|
|
self,
|
|
sampling_rate: int = 22050,
|
|
n_mels: int = 64,
|
|
use_fft_mag: bool = True,
|
|
return_mel: bool = False,
|
|
):
|
|
super().__init__()
|
|
self.wav_to_specs = []
|
|
for i in range(5, 12):
|
|
s = 2**i
|
|
self.wav_to_specs.append(
|
|
MelSpectrogram(
|
|
sample_rate=sampling_rate,
|
|
n_fft=max(s, 512),
|
|
win_length=s,
|
|
hop_length=s // 4,
|
|
n_mels=n_mels,
|
|
)
|
|
)
|
|
self.return_mel = return_mel
|
|
|
|
def forward(
|
|
self,
|
|
x_hat: torch.Tensor,
|
|
x: torch.Tensor,
|
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
|
|
"""Calculate Mel-spectrogram loss.
|
|
|
|
Args:
|
|
x_hat (Tensor): Generated waveform tensor (B, 1, T).
|
|
x (Tensor): Groundtruth waveform tensor (B, 1, T).
|
|
spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor
|
|
(B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth
|
|
waveform.
|
|
|
|
Returns:
|
|
Tensor: Mel-spectrogram loss value.
|
|
|
|
"""
|
|
mel_loss = 0.0
|
|
|
|
for i, wav_to_spec in enumerate(self.wav_to_specs):
|
|
s = 2 ** (i + 5)
|
|
wav_to_spec.to(x.device)
|
|
|
|
mel_hat = wav_to_spec(x_hat.squeeze(1))
|
|
mel = wav_to_spec(x.squeeze(1))
|
|
|
|
mel_loss += (
|
|
F.l1_loss(mel_hat, mel, reduce=True, reduction="mean")
|
|
+ (
|
|
(
|
|
(torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7))
|
|
** 2
|
|
).mean(dim=-2)
|
|
** 0.5
|
|
).mean()
|
|
)
|
|
|
|
# mel_hat = self.wav_to_spec(x_hat.squeeze(1))
|
|
# mel = self.wav_to_spec(x.squeeze(1))
|
|
# mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel)
|
|
|
|
if self.return_mel:
|
|
return mel_loss, (mel_hat, mel)
|
|
|
|
return mel_loss
|
|
|
|
|
|
class WavReconstructionLoss(torch.nn.Module):
|
|
"""Wav Reconstruction loss."""
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(
|
|
self,
|
|
x_hat: torch.Tensor,
|
|
x: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
"""Calculate wav loss.
|
|
|
|
Args:
|
|
x_hat (Tensor): Generated waveform tensor (B, 1, T).
|
|
x (Tensor): Groundtruth waveform tensor (B, 1, T).
|
|
|
|
Returns:
|
|
Tensor: Wav loss value.
|
|
|
|
"""
|
|
wav_loss = F.l1_loss(x, x_hat)
|
|
|
|
return wav_loss
|