2024-09-06 18:07:50 +08:00

299 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn.functional as F
from torchaudio.transforms import MelSpectrogram
def adversarial_g_loss(y_disc_gen):
    """Generator-side hinge loss averaged over discriminator outputs.

    Args:
        y_disc_gen: sequence of discriminator logit tensors computed on
            generated audio, one entry per discriminator output.

    Returns:
        The mean, over the entries of ``y_disc_gen``, of
        ``relu(1 - logits).mean()``.
    """
    per_output = [F.relu(1 - logits).mean().squeeze() for logits in y_disc_gen]
    # Start the sum from 0.0 so the accumulation matches a float-initialised loop.
    return sum(per_output, 0.0) / len(y_disc_gen)
def feature_loss(fmap_r, fmap_gen):
    """Relative L1 feature-matching loss between real and generated feature maps.

    For every feature map the term is
    ``mean(|real - generated| / mean(|real|))``; the result is the average of
    these terms over all feature maps of all discriminators.

    The average is taken over the number of terms actually accumulated, so
    discriminators that expose different numbers of feature maps are
    weighted correctly (the count is not inferred from the first
    discriminator only).

    Args:
        fmap_r: nested sequence ``fmap_r[i][j]`` of feature-map tensors
            computed on real audio (discriminator ``i``, layer ``j``).
        fmap_gen: nested sequence with the same layout computed on
            generated audio.

    Returns:
        The averaged feature-matching loss.

    Raises:
        ValueError: if ``fmap_r`` contains no feature maps at all.
    """
    total = 0.0
    n_terms = 0
    for disc_idx, real_layers in enumerate(fmap_r):
        gen_layers = fmap_gen[disc_idx]
        for layer_idx, real in enumerate(real_layers):
            # Normalise by the mean magnitude of the real feature map.
            total += ((real - gen_layers[layer_idx]).abs() / real.abs().mean()).mean()
            n_terms += 1
    if n_terms == 0:
        raise ValueError("feature_loss requires at least one feature map")
    return total / n_terms
def sim_loss(y_disc_r, y_disc_gen):
    """Mean-squared error between real and generated logits, averaged over outputs.

    Args:
        y_disc_r: sequence of discriminator logit tensors for real audio.
        y_disc_gen: sequence of discriminator logit tensors for generated
            audio, indexed in parallel with ``y_disc_r``.

    Returns:
        The mean over the entries of ``F.mse_loss(real, generated)``.
    """
    mse_terms = (
        F.mse_loss(real_logits, y_disc_gen[idx])
        for idx, real_logits in enumerate(y_disc_r)
    )
    return sum(mse_terms, 0.0) / len(y_disc_r)
# def sisnr_loss(x, s, eps=1e-8):
# """
# calculate training loss
# input:
# x: separated signal, N x S tensor, estimate value
# s: reference signal, N x S tensor, True value
# Return:
# sisnr: N tensor
# """
# if x.shape != s.shape:
# if x.shape[-1] > s.shape[-1]:
# x = x[:, :s.shape[-1]]
# else:
# s = s[:, :x.shape[-1]]
# def l2norm(mat, keepdim=False):
# return torch.norm(mat, dim=-1, keepdim=keepdim)
# if x.shape != s.shape:
# raise RuntimeError(
# "Dimension mismatch when calculating si-snr, {} vs {}".format(
# x.shape, s.shape))
# x_zm = x - torch.mean(x, dim=-1, keepdim=True)
# s_zm = s - torch.mean(s, dim=-1, keepdim=True)
# t = torch.sum(
# x_zm * s_zm, dim=-1,
# keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
# loss = -20. * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
# return torch.sum(loss) / x.shape[0]
# MelSpectrogram transforms keyed by (sample_rate, window size, device string).
# They are built once per key and reused across calls instead of being
# rebuilt and moved to the device on every call.
_MELSPEC_CACHE = {}


def _get_melspec(sample_rate, win_length, device):
    """Return a cached 64-band MelSpectrogram for this rate, window and device."""
    key = (sample_rate, win_length, str(device))
    melspec = _MELSPEC_CACHE.get(key)
    if melspec is None:
        melspec = MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=max(win_length, 512),
            win_length=win_length,
            hop_length=win_length // 4,
            n_mels=64,
            wkwargs={"device": device},
        ).to(device)
        _MELSPEC_CACHE[key] = melspec
    return melspec


def reconstruction_loss(x, x_hat, args, eps=1e-7):
    """Waveform plus multi-scale mel-spectrogram reconstruction loss.

    The loss is ``args.lambda_wav * mse(x, x_hat)`` plus, for each window
    size ``s`` in 64, 128, ..., 2048, an L1 term on the mel spectrograms and
    a log-mel term weighted by ``sqrt(s / 2)``. The scale range and the
    number of mel bands (64) are hard-coded.

    Args:
        x: reference waveform tensor.
        x_hat: reconstructed waveform tensor; the transforms are placed on
            its device.
        args: object providing ``lambda_wav`` and ``sampling_rate``.
        eps: constant added before taking the log of the mel spectrograms.

    Returns:
        The accumulated reconstruction loss.
    """
    # NOTE: this waveform term is a mean-squared error (``F.mse_loss``), not L1.
    # An SI-SNR term used to be sketched here but is disabled.
    total = args.lambda_wav * F.mse_loss(x, x_hat)
    # Window sizes 2^6 = 64 up to 2^11 = 2048.
    for exponent in range(6, 12):
        win = 2**exponent
        melspec = _get_melspec(args.sampling_rate, win, x_hat.device)
        spec_real = melspec(x)
        spec_fake = melspec(x_hat)
        l1_term = (spec_real - spec_fake).abs().mean()
        log_diff = torch.log(spec_real.abs() + eps) - torch.log(spec_fake.abs() + eps)
        # Root-mean-square of the log difference along dim -2, then averaged.
        l2_term = ((log_diff**2).mean(dim=-2) ** 0.5).mean()
        alpha = (win / 2) ** 0.5
        total = total + (l1_term + alpha * l2_term)
    return total
def _hinge_d_term(real_logits, fake_logits):
    """Average discriminator hinge loss over paired real/fake logit tensors."""
    total = 0.0
    for idx, real in enumerate(real_logits):
        total += F.relu(1 - real).mean() + F.relu(1 + fake_logits[idx]).mean()
    # Normalise by the number of pairs actually accumulated.
    return total / len(real_logits)


def criterion_d(
    y_disc_r,
    y_disc_gen,
    fmap_r_det,
    fmap_gen_det,
    y_df_hat_r,
    y_df_hat_g,
    fmap_f_r,
    fmap_f_g,
    y_ds_hat_r,
    y_ds_hat_g,
    fmap_s_r,
    fmap_s_g,
):
    """Hinge loss for three discriminator groups, averaged across groups.

    For each group the loss is the mean over its outputs of
    ``relu(1 - real).mean() + relu(1 + fake).mean()``; the three group
    losses are then averaged.

    Each group is normalised by the number of real outputs that were
    iterated (the first group is no longer divided by ``len(y_disc_gen)``).

    Args:
        y_disc_r, y_disc_gen: real / generated logits of the first group.
        y_df_hat_r, y_df_hat_g: real / generated logits of the second group.
        y_ds_hat_r, y_ds_hat_g: real / generated logits of the third group.
        fmap_r_det, fmap_gen_det, fmap_f_r, fmap_f_g, fmap_s_r, fmap_s_g:
            feature maps; accepted for signature compatibility but unused.

    Returns:
        The averaged hinge loss.
    """
    group_losses = (
        _hinge_d_term(y_disc_r, y_disc_gen)
        + _hinge_d_term(y_df_hat_r, y_df_hat_g)
        + _hinge_d_term(y_ds_hat_r, y_ds_hat_g)
    )
    return group_losses / 3.0
def criterion_g(
    commit_loss,
    x,
    G_x,
    fmap_r,
    fmap_gen,
    y_disc_r,
    y_disc_gen,
    y_df_hat_r,
    y_df_hat_g,
    fmap_f_r,
    fmap_f_g,
    y_ds_hat_r,
    y_ds_hat_g,
    fmap_s_r,
    fmap_s_g,
    args,
):
    """Weighted generator objective: commit + adversarial + feature + reconstruction.

    The adversarial term is computed from ``y_disc_gen`` only. The feature
    term sums ``feature_loss`` and ``sim_loss`` over the three
    discriminator groups and divides the six-term sum by 3.

    Args:
        commit_loss: commitment loss, weighted by ``args.lambda_com``.
        x: reference waveform.
        G_x: generated waveform.
        fmap_r, fmap_gen: real / generated feature maps of the first group.
        y_disc_r, y_disc_gen: real / generated logits of the first group.
        y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g: logits and feature maps
            of the second group.
        y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g: logits and feature maps
            of the third group.
        args: object providing ``lambda_com``, ``lambda_adv``,
            ``lambda_feat``, ``lambda_rec`` and whatever
            ``reconstruction_loss`` reads.

    Returns:
        Tuple ``(total_loss, adv_g_loss, feat_loss, rec_loss)``.
    """
    adv_g_loss = adversarial_g_loss(y_disc_gen)

    # Feature-map and logit matching terms, in accumulation order.
    matching_terms = [
        feature_loss(fmap_r, fmap_gen),
        sim_loss(y_disc_r, y_disc_gen),
        feature_loss(fmap_f_r, fmap_f_g),
        sim_loss(y_df_hat_r, y_df_hat_g),
        feature_loss(fmap_s_r, fmap_s_g),
        sim_loss(y_ds_hat_r, y_ds_hat_g),
    ]
    matching_sum = matching_terms[0]
    for term in matching_terms[1:]:
        matching_sum = matching_sum + term
    feat_loss = matching_sum / 3.0

    rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args)

    weighted_commit = args.lambda_com * commit_loss
    weighted_adv = args.lambda_adv * adv_g_loss
    weighted_feat = args.lambda_feat * feat_loss
    weighted_rec = args.lambda_rec * rec_loss
    total_loss = weighted_commit + weighted_adv + weighted_feat + weighted_rec
    return total_loss, adv_g_loss, feat_loss, rec_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return ``value`` while ``global_step`` is below ``threshold``, else ``weight``.

    Args:
        weight: weight to use once the threshold step has been reached.
        global_step: current training step.
        threshold: first step at which ``weight`` is returned unchanged.
        value: replacement returned for steps before ``threshold``.
    """
    return value if global_step < threshold else weight
def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
    """Return ``value`` on every third step (step % 3 == 0), else ``weight``.

    Args:
        weight: weight returned on steps that are not multiples of 3.
        global_step: current training step.
        threshold: accepted for signature parity with ``adopt_weight``;
            not used here.
        value: replacement returned on steps that are multiples of 3.
    """
    # Per the original note: on steps 0, 3, 6, 9, 12, ... the discriminator
    # is not updated.
    on_skipped_step = global_step % 3 == 0
    return value if on_skipped_step else weight
def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
    """Compute the adversarial weight from gradient norms at ``last_layer``.

    The ratio ``||d nll_loss / d last_layer|| / (||d g_loss / d last_layer|| + 1e-4)``
    is computed, clamped, detached and scaled by ``args.lambda_adv``.

    Args:
        nll_loss: reconstruction-side loss whose gradient is taken.
        g_loss: adversarial loss whose gradient is taken.
        last_layer: tensor with respect to which both gradients are taken.
        args: object providing ``lambda_adv``.

    Returns:
        The detached weight tensor multiplied by ``args.lambda_adv``.

    Raises:
        ValueError: if ``last_layer`` is ``None``. This is raised explicitly
            rather than via ``assert`` so the check survives ``python -O``.
    """
    if last_layer is None:
        raise ValueError("last_layer cannot be none")

    # retain_graph=True keeps the graph alive for later backward passes.
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    # NOTE(review): both clamp bounds are 1.0, so any non-NaN ratio collapses
    # to exactly 1.0 and the result is effectively ``args.lambda_adv``.
    # Kept as-is; confirm whether a wider range was intended.
    d_weight = torch.clamp(d_weight, 1.0, 1.0).detach()
    return d_weight * args.lambda_adv
def loss_g(
    codebook_loss,
    speech,
    speech_hat,
    fmap,
    fmap_hat,
    y,
    y_hat,
    global_step,
    y_df,
    y_df_hat,
    y_ds,
    y_ds_hat,
    fmap_f,
    fmap_f_hat,
    fmap_s,
    fmap_s_hat,
    args=None,
):
    """Total generator loss: reconstruction, adversarial, feature matching, commit.

    Args:
        codebook_loss: commit loss.
        speech: ground-truth wav.
        speech_hat: reconstructed wav.
        fmap: real stft-D feature map.
        fmap_hat: fake stft-D feature map.
        y: real stft-D logits (not used in the computation).
        y_hat: fake stft-D logits.
        global_step: global training step.
        y_df: real MPD logits (not used in the computation).
        y_df_hat: fake MPD logits.
        y_ds: real MSD logits (not used in the computation).
        y_ds_hat: fake MSD logits.
        fmap_f: real MPD feature map.
        fmap_f_hat: fake MPD feature map.
        fmap_s: real MSD feature map.
        fmap_s_hat: fake MSD feature map.
        args: object providing ``lambda_adv``, ``lambda_feat``,
            ``lambda_com``, ``discriminator_iter_start`` and whatever
            ``reconstruction_loss`` reads. Defaults to ``None`` but is
            dereferenced unconditionally, so a real object is required.

    Returns:
        Tuple ``(loss, rec_loss, adv_loss, feat_loss_tot, d_weight)``.
    """
    rec_loss = reconstruction_loss(speech.contiguous(), speech_hat.contiguous(), args)

    # Adversarial terms for stft-D, MPD and MSD, averaged over the three.
    adv_stft = adversarial_g_loss(y_hat)
    adv_mpd = adversarial_g_loss(y_df_hat)
    adv_msd = adversarial_g_loss(y_ds_hat)
    adv_loss = (adv_stft + adv_mpd + adv_msd) / 3.0  # NOTE(lsx): need to divide by 3?

    # Feature-matching terms; logit-matching (sim_loss) is not included here.
    fm_stft = feature_loss(fmap, fmap_hat)
    fm_mpd = feature_loss(fmap_f, fmap_f_hat)
    fm_msd = feature_loss(fmap_s, fmap_s_hat)
    feat_loss_tot = (fm_stft + fm_mpd + fm_msd) / 3.0

    d_weight = torch.tensor(1.0)
    disc_factor = adopt_weight(
        args.lambda_adv, global_step, threshold=args.discriminator_iter_start
    )
    # Feature matching is switched off whenever the adversarial factor is 0.
    fm_loss_wt = 0 if disc_factor == 0.0 else args.lambda_feat

    adv_contrib = d_weight * disc_factor * adv_loss
    fm_contrib = fm_loss_wt * feat_loss_tot
    commit_contrib = args.lambda_com * codebook_loss
    loss = rec_loss + adv_contrib + fm_contrib + commit_contrib
    return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
def loss_dis(
    y,
    y_hat,
    fmap,
    fmap_hat,
    y_df,
    y_df_hat,
    fmap_f,
    fmap_f_hat,
    y_ds,
    y_ds_hat,
    fmap_s,
    fmap_s_hat,
    global_step,
    args,
):
    """Discriminator loss: ``criterion_d`` scaled by the step-gated factor.

    The factor is ``args.lambda_adv`` once ``global_step`` reaches
    ``args.discriminator_iter_start`` and the ``adopt_weight`` default
    (0.0) before that.

    Args:
        y, y_hat, fmap, fmap_hat: real / generated logits and feature maps
            passed as the first discriminator group of ``criterion_d``.
        y_df, y_df_hat, fmap_f, fmap_f_hat: same for the second group.
        y_ds, y_ds_hat, fmap_s, fmap_s_hat: same for the third group.
        global_step: current training step.
        args: object providing ``lambda_adv`` and ``discriminator_iter_start``.

    Returns:
        The scaled discriminator loss.
    """
    disc_factor = adopt_weight(
        args.lambda_adv, global_step, threshold=args.discriminator_iter_start
    )
    hinge = criterion_d(
        y_disc_r=y,
        y_disc_gen=y_hat,
        fmap_r_det=fmap,
        fmap_gen_det=fmap_hat,
        y_df_hat_r=y_df,
        y_df_hat_g=y_df_hat,
        fmap_f_r=fmap_f,
        fmap_f_g=fmap_f_hat,
        y_ds_hat_r=y_ds,
        y_ds_hat_g=y_ds_hat,
        fmap_s_r=fmap_s,
        fmap_s_g=fmap_s_hat,
    )
    return disc_factor * hinge