# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
# Copyright 2021 Tomoki Hayashi
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""HiFiGAN-related loss modules.

This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

"""

from typing import List, Tuple, Union

import torch
import torch.distributions as D
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank


class GeneratorAdversarialLoss(torch.nn.Module):
    """Generator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "mse",
    ):
        """Initialize GeneratorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.criterion = self._mse_loss
        else:
            self.criterion = self._hinge_loss

    def forward(
        self,
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """Calculate generator adversarial loss.

        Args:
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs.

        Returns:
            Tensor: Generator adversarial loss value. If ``outputs`` is an empty
                list/tuple, the plain float ``0.0`` is returned.

        """
        if isinstance(outputs, (tuple, list)):
            adv_loss = 0.0
            # Counted explicitly (instead of reusing the loop index afterwards)
            # so that an empty list does not hit an unbound variable.
            num_discriminators = 0
            for num_discriminators, outputs_ in enumerate(outputs, start=1):
                if isinstance(outputs_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_ = outputs_[-1]
                adv_loss += self.criterion(outputs_)
            if self.average_by_discriminators and num_discriminators > 0:
                adv_loss /= num_discriminators
        else:
            adv_loss = self.criterion(outputs)

        return adv_loss

    def _mse_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return the MSE between ``x`` and an all-ones target."""
        return F.mse_loss(x, x.new_ones(x.size()))

    def _hinge_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return the negated mean of ``x``."""
        return -x.mean()


class DiscriminatorAdversarialLoss(torch.nn.Module):
    """Discriminator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "mse",
    ):
        """Initialize DiscriminatorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss
        else:
            self.fake_criterion = self._hinge_fake_loss
            self.real_criterion = self._hinge_real_loss

    def forward(
        self,
        outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs calculated from generator.
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs calculated from groundtruth.

        Returns:
            Tensor: Discriminator real loss value.
            Tensor: Discriminator fake loss value.

            If the inputs are empty lists/tuples, both values are the plain
            float ``0.0``.

        """
        if isinstance(outputs, (tuple, list)):
            real_loss = 0.0
            fake_loss = 0.0
            # Number of (generated, groundtruth) pairs actually processed;
            # stays 0 for empty inputs so the averaging below is skipped.
            num_discriminators = 0
            for num_discriminators, (outputs_hat_, outputs_) in enumerate(
                zip(outputs_hat, outputs), start=1
            ):
                if isinstance(outputs_hat_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_hat_ = outputs_hat_[-1]
                    outputs_ = outputs_[-1]
                real_loss += self.real_criterion(outputs_)
                fake_loss += self.fake_criterion(outputs_hat_)
            if self.average_by_discriminators and num_discriminators > 0:
                fake_loss /= num_discriminators
                real_loss /= num_discriminators
        else:
            real_loss = self.real_criterion(outputs)
            fake_loss = self.fake_criterion(outputs_hat)

        return real_loss, fake_loss

    def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return the MSE between ``x`` and an all-ones target."""
        return F.mse_loss(x, x.new_ones(x.size()))

    def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return the MSE between ``x`` and an all-zeros target."""
        return F.mse_loss(x, x.new_zeros(x.size()))

    def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``-mean(min(x - 1, 0))``."""
        return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))

    def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``-mean(min(-x - 1, 0))``."""
        return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))


class FeatureMatchLoss(torch.nn.Module):
    """Feature matching loss module."""

    def __init__(
        self,
        average_by_layers: bool = True,
        average_by_discriminators: bool = True,
        include_final_outputs: bool = False,
    ):
        """Initialize FeatureMatchLoss module.

        Args:
            average_by_layers (bool): Whether to average the loss by the number
                of layers.
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            include_final_outputs (bool): Whether to include the final output of
                each discriminator for loss calculation.

        """
        super().__init__()
        self.average_by_layers = average_by_layers
        self.average_by_discriminators = average_by_discriminators
        self.include_final_outputs = include_final_outputs

    def forward(
        self,
        feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
        feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
    ) -> torch.Tensor:
        """Calculate feature matching loss.

        Args:
            feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from generator's outputs.
            feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from groundtruth.

        Returns:
            Tensor: Feature matching loss value. If no feature pair is compared
                at all (empty inputs), the plain float ``0.0`` is returned.

        """
        feat_match_loss = 0.0
        num_discriminators = 0
        for num_discriminators, (feats_hat_, feats_) in enumerate(
            zip(feats_hat, feats), start=1
        ):
            feat_match_loss_ = 0.0
            if not self.include_final_outputs:
                feats_hat_ = feats_hat_[:-1]
                feats_ = feats_[:-1]
            # Reset per discriminator so a discriminator that contributes no
            # feature maps neither reuses a stale count nor hits an unbound name.
            num_layers = 0
            for num_layers, (feat_hat_, feat_) in enumerate(
                zip(feats_hat_, feats_), start=1
            ):
                # Groundtruth features are detached: gradients only flow
                # through the generated-side features.
                feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
            if self.average_by_layers and num_layers > 0:
                feat_match_loss_ /= num_layers
            feat_match_loss += feat_match_loss_
        if self.average_by_discriminators and num_discriminators > 0:
            feat_match_loss /= num_discriminators

        return feat_match_loss


class MelSpectrogramLoss(torch.nn.Module):
    """Mel-spectrogram loss."""

    def __init__(
        self,
        sampling_rate: int = 22050,
        frame_length: int = 1024,  # in samples
        frame_shift: int = 256,  # in samples
        n_mels: int = 80,
        use_fft_mag: bool = True,
    ):
        """Initialize MelSpectrogramLoss module.

        Args:
            sampling_rate (int): Sampling rate of the waveforms.
            frame_length (int): Analysis frame length in samples.
            frame_shift (int): Frame shift in samples.
            n_mels (int): Number of filters passed to the feature extractor.
            use_fft_mag (bool): Forwarded as-is to ``Wav2LogFilterBank``.

        """
        super().__init__()
        # Wav2LogFilterBank takes frame sizes in seconds, hence the division
        # of the sample counts by the sampling rate.
        self.wav_to_mel = Wav2LogFilterBank(
            sampling_rate=sampling_rate,
            frame_length=frame_length / sampling_rate,  # in second
            frame_shift=frame_shift / sampling_rate,  # in second
            use_fft_mag=use_fft_mag,
            num_filters=n_mels,
        )

    def forward(
        self,
        y_hat: torch.Tensor,
        y: torch.Tensor,
        return_mel: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Calculate Mel-spectrogram loss.

        Args:
            y_hat (Tensor): Generated waveform tensor (B, 1, T).
            y (Tensor): Groundtruth waveform tensor (B, 1, T).
            return_mel (bool): Whether to additionally return the features
                computed from ``y_hat`` and ``y``.

        Returns:
            Tensor: Mel-spectrogram loss value (L1 distance between the two
                feature tensors). When ``return_mel`` is True, a tuple
                ``(mel_loss, (mel_hat, mel))`` is returned instead.

        """
        # Drop the channel dimension before feature extraction.
        mel_hat = self.wav_to_mel(y_hat.squeeze(1))
        mel = self.wav_to_mel(y.squeeze(1))
        mel_loss = F.l1_loss(mel_hat, mel)

        if return_mel:
            return mel_loss, (mel_hat, mel)

        return mel_loss


# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py

"""VITS-related loss modules.

This code is based on https://github.com/jaywalnut310/vits.

"""


class KLDivergenceLoss(torch.nn.Module):
    """KL divergence loss."""

    def forward(
        self,
        z_p: torch.Tensor,
        logs_q: torch.Tensor,
        m_p: torch.Tensor,
        logs_p: torch.Tensor,
        z_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate KL divergence loss.

        Args:
            z_p (Tensor): Flow hidden representation (B, H, T_feats).
            logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
            m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
            logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
            z_mask (Tensor): Mask tensor (B, 1, T_feats).

        Returns:
            Tensor: KL divergence loss.

        """
        # All inputs are cast to float32 before the computation.
        z_p = z_p.float()
        logs_q = logs_q.float()
        m_p = m_p.float()
        logs_p = logs_p.float()
        z_mask = z_mask.float()
        kl = logs_p - logs_q - 0.5
        kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
        # Masked sum, normalized by the sum of the mask.
        kl = torch.sum(kl * z_mask)
        loss = kl / torch.sum(z_mask)

        return loss


class KLDivergenceLossWithoutFlow(torch.nn.Module):
    """KL divergence loss without flow."""

    def forward(
        self,
        m_q: torch.Tensor,
        logs_q: torch.Tensor,
        m_p: torch.Tensor,
        logs_p: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate KL divergence loss without flow.

        Args:
            m_q (Tensor): Posterior encoder projected mean (B, H, T_feats).
            logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
            m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
            logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).

        Returns:
            Tensor: Mean over all elements of the KL divergence between the
                posterior and prior normal distributions.

        """
        # The log-scales are exponentiated to obtain standard deviations.
        posterior_norm = D.Normal(m_q, torch.exp(logs_q))
        prior_norm = D.Normal(m_p, torch.exp(logs_p))
        loss = D.kl_divergence(posterior_norm, prior_norm).mean()
        return loss