# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py

# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""HiFiGAN-related loss modules.

This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.

"""

from typing import List, Tuple, Union

import torch
import torch.distributions as D
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank

class GeneratorAdversarialLoss(torch.nn.Module):
    """Generator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "mse",
    ):
        """Initialize GeneratorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.criterion = self._mse_loss
        else:
            self.criterion = self._hinge_loss

    def forward(
        self,
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """Calculate generator adversarial loss.

        Args:
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs.

        Returns:
            Tensor: Generator adversarial loss value.

        """
        if isinstance(outputs, (tuple, list)):
            adv_loss = 0.0
            for i, outputs_ in enumerate(outputs):
                if isinstance(outputs_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_ = outputs_[-1]
                adv_loss += self.criterion(outputs_)
            if self.average_by_discriminators:
                adv_loss /= i + 1
        else:
            adv_loss = self.criterion(outputs)

        return adv_loss

    def _mse_loss(self, x):
        return F.mse_loss(x, x.new_ones(x.size()))

    def _hinge_loss(self, x):
        return -x.mean()
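
# A minimal usage sketch for GeneratorAdversarialLoss (hypothetical shapes,
# not part of the original file; kept as comments so the module stays
# importable). Each list entry plays the role of one sub-discriminator's
# logits for a batch of 4 fake samples:
#
#     criterion = GeneratorAdversarialLoss(loss_type="mse")
#     fake_logits = [torch.randn(4, 1, 100) for _ in range(3)]
#     g_adv = criterion(fake_logits)  # scalar, averaged over the 3 discriminators
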
class DiscriminatorAdversarialLoss(torch.nn.Module):
    """Discriminator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "mse",
    ):
        """Initialize DiscriminatorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss
        else:
            self.fake_criterion = self._hinge_fake_loss
            self.real_criterion = self._hinge_real_loss

    def forward(
        self,
        outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs calculated from generator.
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs calculated from groundtruth.

        Returns:
            Tensor: Discriminator real loss value.
            Tensor: Discriminator fake loss value.

        """
        if isinstance(outputs, (tuple, list)):
            real_loss = 0.0
            fake_loss = 0.0
            for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
                if isinstance(outputs_hat_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_hat_ = outputs_hat_[-1]
                    outputs_ = outputs_[-1]
                real_loss += self.real_criterion(outputs_)
                fake_loss += self.fake_criterion(outputs_hat_)
            if self.average_by_discriminators:
                fake_loss /= i + 1
                real_loss /= i + 1
        else:
            real_loss = self.real_criterion(outputs)
            fake_loss = self.fake_criterion(outputs_hat)

        return real_loss, fake_loss

    def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(x, x.new_ones(x.size()))

    def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(x, x.new_zeros(x.size()))

    def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))

    def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
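
# A minimal usage sketch for DiscriminatorAdversarialLoss (hypothetical
# shapes, not part of the original file). Note the argument order: generator
# outputs first, groundtruth outputs second; in a real training loop the
# generator outputs would be detached before this call:
#
#     criterion = DiscriminatorAdversarialLoss(loss_type="mse")
#     fake_logits = [torch.randn(4, 1, 100) for _ in range(3)]
#     real_logits = [torch.randn(4, 1, 100) for _ in range(3)]
#     real_loss, fake_loss = criterion(fake_logits, real_logits)
#     d_loss = real_loss + fake_loss
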
class FeatureMatchLoss(torch.nn.Module):
    """Feature matching loss module."""

    def __init__(
        self,
        average_by_layers: bool = True,
        average_by_discriminators: bool = True,
        include_final_outputs: bool = False,
    ):
        """Initialize FeatureMatchLoss module.

        Args:
            average_by_layers (bool): Whether to average the loss by the number
                of layers.
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            include_final_outputs (bool): Whether to include the final output of
                each discriminator for loss calculation.

        """
        super().__init__()
        self.average_by_layers = average_by_layers
        self.average_by_discriminators = average_by_discriminators
        self.include_final_outputs = include_final_outputs

    def forward(
        self,
        feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
        feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
    ) -> torch.Tensor:
        """Calculate feature matching loss.

        Args:
            feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from the generator's outputs.
            feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from groundtruth.

        Returns:
            Tensor: Feature matching loss value.

        """
        feat_match_loss = 0.0
        for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
            feat_match_loss_ = 0.0
            if not self.include_final_outputs:
                feats_hat_ = feats_hat_[:-1]
                feats_ = feats_[:-1]
            for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
                feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
            if self.average_by_layers:
                feat_match_loss_ /= j + 1
            feat_match_loss += feat_match_loss_
        if self.average_by_discriminators:
            feat_match_loss /= i + 1

        return feat_match_loss
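
# A minimal usage sketch for FeatureMatchLoss (hypothetical shapes, not part
# of the original file): two sub-discriminators, each exposing three
# intermediate feature maps; with the default include_final_outputs=False the
# last map of each discriminator is skipped:
#
#     criterion = FeatureMatchLoss()
#     feats = [[torch.randn(4, 16, 50) for _ in range(3)] for _ in range(2)]
#     feats_hat = [[torch.randn(4, 16, 50) for _ in range(3)] for _ in range(2)]
#     fm_loss = criterion(feats_hat, feats)
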
class MelSpectrogramLoss(torch.nn.Module):
    """Mel-spectrogram loss."""

    def __init__(
        self,
        sampling_rate: int = 22050,
        frame_length: int = 1024,  # in samples
        frame_shift: int = 256,  # in samples
        n_mels: int = 80,
        use_fft_mag: bool = True,
    ):
        super().__init__()
        self.wav_to_mel = Wav2LogFilterBank(
            sampling_rate=sampling_rate,
            frame_length=frame_length / sampling_rate,  # in seconds
            frame_shift=frame_shift / sampling_rate,  # in seconds
            use_fft_mag=use_fft_mag,
            num_filters=n_mels,
        )

    def forward(
        self,
        y_hat: torch.Tensor,
        y: torch.Tensor,
        return_mel: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Calculate Mel-spectrogram loss.

        Args:
            y_hat (Tensor): Generated waveform tensor (B, 1, T).
            y (Tensor): Groundtruth waveform tensor (B, 1, T).
            return_mel (bool): Whether to also return the computed
                mel-spectrograms of the generated and groundtruth waveforms.

        Returns:
            Tensor: Mel-spectrogram loss value.
            Tuple[Tensor, Tensor] (optional): Mel-spectrograms of the generated
                and groundtruth waveforms, returned when return_mel is True.

        """
        mel_hat = self.wav_to_mel(y_hat.squeeze(1))
        mel = self.wav_to_mel(y.squeeze(1))
        mel_loss = F.l1_loss(mel_hat, mel)

        if return_mel:
            return mel_loss, (mel_hat, mel)

        return mel_loss
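
# A minimal usage sketch for MelSpectrogramLoss (hypothetical shapes, not
# part of the original file): waveforms are (B, 1, T) and are squeezed to
# (B, T) before the log-filterbank is applied:
#
#     criterion = MelSpectrogramLoss(sampling_rate=22050)
#     y_hat = torch.randn(4, 1, 8192)
#     y = torch.randn(4, 1, 8192)
#     mel_loss = criterion(y_hat, y)
#     mel_loss, (mel_hat, mel) = criterion(y_hat, y, return_mel=True)
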
# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py

"""VITS-related loss modules.

This code is based on https://github.com/jaywalnut310/vits.

"""

class KLDivergenceLoss(torch.nn.Module):
    """KL divergence loss."""

    def forward(
        self,
        z_p: torch.Tensor,
        logs_q: torch.Tensor,
        m_p: torch.Tensor,
        logs_p: torch.Tensor,
        z_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate KL divergence loss.

        Args:
            z_p (Tensor): Flow hidden representation (B, H, T_feats).
            logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
            m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
            logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
            z_mask (Tensor): Mask tensor (B, 1, T_feats).

        Returns:
            Tensor: KL divergence loss.

        """
        z_p = z_p.float()
        logs_q = logs_q.float()
        m_p = m_p.float()
        logs_p = logs_p.float()
        z_mask = z_mask.float()
        kl = logs_p - logs_q - 0.5
        kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
        kl = torch.sum(kl * z_mask)
        loss = kl / torch.sum(z_mask)

        return loss
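
# Derivation note (added for reference; matches the computation above). With
# a single flow sample z_p ~ q, the per-element Monte-Carlo estimate of
# KL(q || p) between diagonal Gaussians is
#
#     log(sigma_p) - log(sigma_q) - 1/2 + (z_p - m_p)^2 / (2 * sigma_p^2)
#
# The -1/2 term is the expectation of (z - m_q)^2 / (2 * sigma_q^2) under q,
# and exp(-2.0 * logs_p) above is 1 / sigma_p^2. Masked summation followed by
# division by sum(z_mask) averages over the valid frames only.
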
class KLDivergenceLossWithoutFlow(torch.nn.Module):
    """KL divergence loss without flow."""

    def forward(
        self,
        m_q: torch.Tensor,
        logs_q: torch.Tensor,
        m_p: torch.Tensor,
        logs_p: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate KL divergence loss without flow.

        Args:
            m_q (Tensor): Posterior encoder projected mean (B, H, T_feats).
            logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
            m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
            logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).

        Returns:
            Tensor: KL divergence loss.

        """
        posterior_norm = D.Normal(m_q, torch.exp(logs_q))
        prior_norm = D.Normal(m_p, torch.exp(logs_p))
        loss = D.kl_divergence(posterior_norm, prior_norm).mean()
        return loss
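
# A minimal usage sketch for KLDivergenceLossWithoutFlow (hypothetical
# shapes, not part of the original file): the same quantity in closed form
# via torch.distributions, averaged over all elements instead of masked:
#
#     criterion = KLDivergenceLossWithoutFlow()
#     m_q, logs_q = torch.randn(4, 192, 100), torch.randn(4, 192, 100)
#     m_p, logs_p = torch.randn(4, 192, 100), torch.randn(4, 192, 100)
#     loss = criterion(m_q, logs_q, m_p, logs_p)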