from typing import List, Tuple, Union

import torch
import torch.nn.functional as F
from lhotse.features.kaldi import Wav2LogFilterBank
from torchaudio.transforms import MelSpectrogram


class GeneratorAdversarialLoss(torch.nn.Module):
    """Generator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "hinge",
    ):
        """Initialize GeneratorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.criterion = self._mse_loss
        else:
            self.criterion = self._hinge_loss

    def forward(
        self,
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """Calculate generator adversarial loss.

        Args:
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs.

        Returns:
            Tensor: Generator adversarial loss value.

        """
        adv_loss = 0.0
        if isinstance(outputs, (tuple, list)):
            for i, outputs_ in enumerate(outputs):
                if isinstance(outputs_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_ = outputs_[-1]
                adv_loss += self.criterion(outputs_)
            if self.average_by_discriminators:
                adv_loss /= i + 1
        else:
            # A plain tensor is iterated element-wise, treating each entry as
            # a per-discriminator score; note that this branch always
            # averages, regardless of ``average_by_discriminators``.
            for i, outputs_ in enumerate(outputs):
                adv_loss += self.criterion(outputs_)
            adv_loss /= i + 1
        return adv_loss

    def _mse_loss(self, x):
        return F.mse_loss(x, x.new_ones(x.size()))

    def _hinge_loss(self, x):
        return F.relu(1 - x).mean()
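
# Usage sketch (an illustration, not part of the recipe): hinge-type
# generator loss over two hypothetical discriminators, each returning a list
# whose last element is the logits tensor (earlier elements are feature maps).
#
#     criterion = GeneratorAdversarialLoss(loss_type="hinge")
#     outs = [[torch.rand(4, 16, 16), torch.rand(4, 1)] for _ in range(2)]
#     loss = criterion(outs)  # scalar, averaged over the two discriminators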


class DiscriminatorAdversarialLoss(torch.nn.Module):
    """Discriminator adversarial loss module."""

    def __init__(
        self,
        average_by_discriminators: bool = True,
        loss_type: str = "hinge",
    ):
        """Initialize DiscriminatorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            loss_type (str): Loss type, "mse" or "hinge".

        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss
        else:
            self.fake_criterion = self._hinge_fake_loss
            self.real_criterion = self._hinge_real_loss

    def forward(
        self,
        outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
        outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs calculated from the generator.
            outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
                outputs, list of discriminator outputs, or list of list of
                discriminator outputs calculated from the ground truth.

        Returns:
            Tensor: Discriminator real loss value.
            Tensor: Discriminator fake loss value.

        """
        real_loss = 0.0
        fake_loss = 0.0
        if isinstance(outputs, (tuple, list)):
            for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
                if isinstance(outputs_hat_, (tuple, list)):
                    # NOTE(kan-bayashi): case including feature maps
                    outputs_hat_ = outputs_hat_[-1]
                    outputs_ = outputs_[-1]
                real_loss += self.real_criterion(outputs_)
                fake_loss += self.fake_criterion(outputs_hat_)
            if self.average_by_discriminators:
                fake_loss /= i + 1
                real_loss /= i + 1
        else:
            # Plain tensors are iterated element-wise; note that this branch
            # always averages, regardless of ``average_by_discriminators``.
            for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
                real_loss += self.real_criterion(outputs_)
                fake_loss += self.fake_criterion(outputs_hat_)
            fake_loss /= i + 1
            real_loss /= i + 1

        return real_loss, fake_loss

    def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(x, x.new_ones(x.size()))

    def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(x, x.new_zeros(x.size()))

    def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(x.new_ones(x.size()) - x).mean()

    def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(x.new_ones(x.size()) + x).mean()
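
# Usage sketch (illustrative only): real/fake hinge losses for the same two
# hypothetical discriminators as above, given outputs for generated and
# ground-truth audio.
#
#     criterion = DiscriminatorAdversarialLoss(loss_type="hinge")
#     real_outs = [[torch.rand(4, 16, 16), torch.rand(4, 1)] for _ in range(2)]
#     fake_outs = [[torch.rand(4, 16, 16), torch.rand(4, 1)] for _ in range(2)]
#     real_loss, fake_loss = criterion(fake_outs, real_outs)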


class FeatureLoss(torch.nn.Module):
    """Feature loss module."""

    def __init__(
        self,
        average_by_layers: bool = True,
        average_by_discriminators: bool = True,
        include_final_outputs: bool = True,
    ):
        """Initialize FeatureMatchLoss module.

        Args:
            average_by_layers (bool): Whether to average the loss by the number
                of layers.
            average_by_discriminators (bool): Whether to average the loss by
                the number of discriminators.
            include_final_outputs (bool): Whether to include the final output of
                each discriminator for loss calculation.

        """
        super().__init__()
        self.average_by_layers = average_by_layers
        self.average_by_discriminators = average_by_discriminators
        self.include_final_outputs = include_final_outputs

    def forward(
        self,
        feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
        feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
    ) -> torch.Tensor:
        """Calculate feature matching loss.

        Args:
            feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from the generator's outputs.
            feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
                discriminator outputs or list of discriminator outputs calculated
                from the ground truth.

        Returns:
            Tensor: Feature matching loss value.

        """
        feat_match_loss = 0.0
        for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
            feat_match_loss_ = 0.0
            if not self.include_final_outputs:
                feats_hat_ = feats_hat_[:-1]
                feats_ = feats_[:-1]
            for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
                # L1 distance, normalized by the mean magnitude of the
                # ground-truth feature map.
                feat_match_loss_ += (
                    (feat_hat_ - feat_).abs() / (feat_.abs().mean())
                ).mean()
            if self.average_by_layers:
                feat_match_loss_ /= j + 1
            feat_match_loss += feat_match_loss_
        if self.average_by_discriminators:
            feat_match_loss /= i + 1

        return feat_match_loss
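
# Usage sketch (illustrative only): feature matching over two hypothetical
# discriminators with three feature maps each.
#
#     criterion = FeatureLoss(include_final_outputs=True)
#     feats = [[torch.rand(4, 8, 8) for _ in range(3)] for _ in range(2)]
#     feats_hat = [[torch.rand(4, 8, 8) for _ in range(3)] for _ in range(2)]
#     loss = criterion(feats_hat, feats)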


class MelSpectrogramReconstructionLoss(torch.nn.Module):
    """Mel-spectrogram reconstruction loss."""

    def __init__(
        self,
        sampling_rate: int = 22050,
        n_mels: int = 64,
        use_fft_mag: bool = True,
        return_mel: bool = False,
    ):
        super().__init__()
        # One mel extractor per window size, from 2^5 = 32 up to 2^11 = 2048
        # samples.
        self.wav_to_specs = []
        for i in range(5, 12):
            s = 2**i
            # self.wav_to_specs.append(
            #     Wav2LogFilterBank(
            #         sampling_rate=sampling_rate,
            #         frame_length=s,
            #         frame_shift=s // 4,
            #         use_fft_mag=use_fft_mag,
            #         num_filters=n_mels,
            #     )
            # )
            self.wav_to_specs.append(
                MelSpectrogram(
                    sample_rate=sampling_rate,
                    n_fft=max(s, 512),
                    win_length=s,
                    hop_length=s // 4,
                    n_mels=n_mels,
                )
            )
        self.return_mel = return_mel

    def forward(
        self,
        x_hat: torch.Tensor,
        x: torch.Tensor,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
        """Calculate Mel-spectrogram loss.

        Args:
            x_hat (Tensor): Generated waveform tensor (B, 1, T).
            x (Tensor): Groundtruth waveform tensor (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram loss value.

        """
        mel_loss = 0.0

        for wav_to_spec in self.wav_to_specs:
            # The extractors live in a plain list rather than a ModuleList,
            # so they do not follow the parent module across devices; move
            # them explicitly.
            wav_to_spec.to(x.device)

            mel_hat = wav_to_spec(x_hat.squeeze(1))
            mel = wav_to_spec(x.squeeze(1))

            # Sum of L1 and L2 distances between the two mel spectrograms.
            mel_loss += F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel)

        # mel_hat = self.wav_to_spec(x_hat.squeeze(1))
        # mel = self.wav_to_spec(x.squeeze(1))
        # mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel)

        if self.return_mel:
            return mel_loss, (mel_hat, mel)

        return mel_loss
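
# Usage sketch (illustrative only): multi-scale mel loss on a one-second
# batch of generated/real waveforms at 22.05 kHz.
#
#     criterion = MelSpectrogramReconstructionLoss(sampling_rate=22050)
#     x = torch.randn(2, 1, 22050)
#     x_hat = torch.randn(2, 1, 22050)
#     loss = criterion(x_hat, x)  # sums L1 + L2 over seven window sizes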


class WavReconstructionLoss(torch.nn.Module):
    """Wav reconstruction loss."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        x_hat: torch.Tensor,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate wav loss.

        Args:
            x_hat (Tensor): Generated waveform tensor (B, 1, T).
            x (Tensor): Groundtruth waveform tensor (B, 1, T).

        Returns:
            Tensor: Wav loss value.

        """
        wav_loss = F.l1_loss(x, x_hat)

        return wav_loss


def adversarial_g_loss(y_disc_gen):
    """Hinge loss for the generator, averaged over discriminator outputs."""
    loss = 0.0
    for i in range(len(y_disc_gen)):
        stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze()
        loss += stft_loss
    return loss / len(y_disc_gen)


def feature_loss(fmap_r, fmap_gen):
    """Feature matching loss between real and generated feature maps."""
    loss = 0.0
    for i in range(len(fmap_r)):
        for j in range(len(fmap_r[i])):
            stft_loss = (
                (fmap_r[i][j] - fmap_gen[i][j]).abs() / (fmap_r[i][j].abs().mean())
            ).mean()
            loss += stft_loss
    return loss / (len(fmap_r) * len(fmap_r[0]))


def sim_loss(y_disc_r, y_disc_gen):
    """MSE similarity between real and generated discriminator logits."""
    loss = 0.0
    for i in range(len(y_disc_r)):
        loss += F.mse_loss(y_disc_r[i], y_disc_gen[i])
    return loss / len(y_disc_r)


def reconstruction_loss(x, x_hat, args, eps=1e-7):
    # NOTE (lsx): hard-coded now
    L = args.lambda_wav * F.mse_loss(x, x_hat)  # wav L2 (MSE) loss
    # loss_sisnr = sisnr_loss(G_x, x) #
    # L += 0.01*loss_sisnr
    # 2^6=64 -> 2^10=1024
    # NOTE (lsx): add 2^11
    for i in range(6, 12):
        # for i in range(5, 12):  # Encodec setting
        s = 2**i
        melspec = MelSpectrogram(
            sample_rate=args.sampling_rate,
            n_fft=max(s, 512),
            win_length=s,
            hop_length=s // 4,
            n_mels=64,
            wkwargs={"device": x_hat.device},
        ).to(x_hat.device)
        S_x = melspec(x)
        S_x_hat = melspec(x_hat)
        l1_loss = (S_x - S_x_hat).abs().mean()
        # Frame-wise RMS over mel bins of the log-mel difference.
        l2_loss = (
            ((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean(
                dim=-2
            )
            ** 0.5
        ).mean()

        alpha = (s / 2) ** 0.5
        L += l1_loss + alpha * l2_loss
    return L
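
# Shape of the loss above (for reference): with S_i the mel spectrogram at
# window size s = 2**i,
#
#     L = lambda_wav * MSE(x, x_hat)
#         + sum_i [ mean|S_i(x) - S_i(x_hat)|
#                   + sqrt(s / 2) * RMS(log S_i(x) - log S_i(x_hat)) ]
#
# i.e. the multi-scale spectrogram term used in Encodec-style training.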


def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return ``value`` in place of ``weight`` until ``global_step`` reaches ``threshold``."""
    if global_step < threshold:
        weight = value
    return weight


def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
    if last_layer is None:
        raise ValueError("last_layer cannot be None")
    # Balance the adversarial term against the reconstruction term by the
    # ratio of their gradient norms w.r.t. the last generator layer.
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    # Clamping to [1.0, 1.0] currently pins the ratio to exactly 1.
    d_weight = torch.clamp(d_weight, 1.0, 1.0).detach()
    d_weight = d_weight * args.lambda_adv
    return d_weight
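
# The gradient-ratio heuristic above follows the adaptive generator weight
# used in VQGAN-style training, d_weight = ||grad(nll)|| / (||grad(g)|| + 1e-4),
# although the clamp fixes the effective value to 1.0 * lambda_adv here.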


def loss_g(
    codebook_loss,
    speech,
    speech_hat,
    fmap,
    fmap_hat,
    y,
    y_hat,
    y_df,
    y_df_hat,
    y_ds,
    y_ds_hat,
    fmap_f,
    fmap_f_hat,
    fmap_s,
    fmap_s_hat,
    args=None,
):
    """
    Args:
        codebook_loss: commit loss.
        speech: ground-truth wav.
        speech_hat: reconstructed wav.
        fmap: real stft-D feature map.
        fmap_hat: fake stft-D feature map.
        y: real stft-D logits.
        y_hat: fake stft-D logits.
        y_df: real MPD logits.
        y_df_hat: fake MPD logits.
        y_ds: real MSD logits.
        y_ds_hat: fake MSD logits.
        fmap_f: real MPD feature map.
        fmap_f_hat: fake MPD feature map.
        fmap_s: real MSD feature map.
        fmap_s_hat: fake MSD feature map.
    """
    rec_loss = reconstruction_loss(speech.contiguous(), speech_hat.contiguous(), args)
    adv_g_loss = adversarial_g_loss(y_hat)
    adv_mpd_loss = adversarial_g_loss(y_df_hat)
    adv_msd_loss = adversarial_g_loss(y_ds_hat)
    adv_loss = (
        adv_g_loss + adv_mpd_loss + adv_msd_loss
    ) / 3.0  # NOTE(lsx): need to divide by 3?
    feat_loss = feature_loss(
        fmap, fmap_hat
    )  # + sim_loss(y_disc_r, y_disc_gen)  # NOTE(lsx): need logits?
    feat_loss_mpd = feature_loss(
        fmap_f, fmap_f_hat
    )  # + sim_loss(y_df_hat_r, y_df_hat_g)
    feat_loss_msd = feature_loss(
        fmap_s, fmap_s_hat
    )  # + sim_loss(y_ds_hat_r, y_ds_hat_g)
    feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
    d_weight = torch.tensor(1.0)

    # disc_factor = adopt_weight(
    #     args.lambda_adv, global_step, threshold=args.discriminator_iter_start
    # )
    disc_factor = 1
    if disc_factor == 0.0:
        fm_loss_wt = 0
    else:
        fm_loss_wt = args.lambda_feat

    loss = (
        rec_loss
        + d_weight * disc_factor * adv_loss
        + fm_loss_wt * feat_loss_tot
        + args.lambda_com * codebook_loss
    )
    return loss, rec_loss, adv_loss, feat_loss_tot, d_weight


if __name__ == "__main__":
    # la = FeatureLoss(average_by_layers=True, average_by_discriminators=True)
    # aa = [torch.rand(192, 192) for _ in range(3)]
    # bb = [torch.rand(192, 192) for _ in range(3)]
    # print(la(bb, aa))
    # print(feature_loss(aa, bb))
    la = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")
    aa = torch.Tensor([0.1, 0.2, 0.3, 0.4])
    bb = torch.Tensor([0.4, 0.3, 0.2, 0.1])
    print(la(aa))
    print(adversarial_g_loss(aa))
    print(la(bb))
    print(adversarial_g_loss(bb))
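
    # Added sanity check (illustrative): on list-of-list inputs, the
    # class-based hinge loss and the functional one applied to the final
    # logits should agree.
    outs = [[torch.rand(2, 8), torch.rand(2, 1)] for _ in range(2)]
    print(la(outs))
    print(adversarial_g_loss([o[-1] for o in outs]))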