2024-09-06 21:20:45 +08:00

269 lines
6.8 KiB
Python

import torch
import torch.nn.functional as F
from torchaudio.transforms import MelSpectrogram
def adversarial_g_loss(y_disc_gen):
    """Generator adversarial loss (hinge form).

    Each discriminator output ``d`` in ``y_disc_gen`` contributes
    ``mean(relu(1 - d))``. The contributions are averaged over the number of
    outputs.

    Args:
        y_disc_gen: sequence of discriminator logits for generated audio.

    Returns:
        Scalar tensor.
    """
    total = 0.0
    for logits in y_disc_gen:
        total += F.relu(1 - logits).mean().squeeze()
    return total / len(y_disc_gen)
def feature_loss(fmap_r, fmap_gen):
    """Relative L1 feature-matching loss between real and generated feature maps.

    ``fmap_r[i][j]`` and ``fmap_gen[i][j]`` are the j-th feature tensors of the
    i-th sub-discriminator for real and generated input. Each pair contributes
    ``mean(|r - g| / mean(|r|))``: the L1 distance normalised by the mean
    magnitude of the real features. The result is the average over every
    (sub-discriminator, layer) pair that is visited.

    Args:
        fmap_r: nested sequence of feature tensors for real input.
        fmap_gen: nested sequence of feature tensors for generated input,
            indexed like ``fmap_r``.

    Returns:
        Scalar tensor.

    Raises:
        ZeroDivisionError: if ``fmap_r`` contains no feature tensors.
    """
    total = 0.0
    n_terms = 0
    for i, layers_r in enumerate(fmap_r):
        for j, feat_r in enumerate(layers_r):
            # Index into fmap_gen (rather than zip) so that a mismatched
            # structure still raises instead of being silently truncated.
            feat_g = fmap_gen[i][j]
            total += ((feat_r - feat_g).abs() / (feat_r.abs().mean())).mean()
            n_terms += 1
    # Count the terms explicitly. Do not assume every sub-discriminator
    # has as many layers as the first one.
    return total / n_terms
def sim_loss(y_disc_r, y_disc_gen):
    """Average mean-squared error between paired real/generated discriminator logits.

    Args:
        y_disc_r: sequence of logits for real input.
        y_disc_gen: sequence of logits for generated input, indexed like
            ``y_disc_r``.

    Returns:
        Scalar tensor: the per-pair MSE values averaged over ``len(y_disc_r)``.
    """
    per_pair = [F.mse_loss(real, y_disc_gen[k]) for k, real in enumerate(y_disc_r)]
    return sum(per_pair, 0.0) / len(y_disc_r)
def reconstruction_loss(x, x_hat, args, eps=1e-7):
    """Waveform plus multi-resolution mel-spectrogram reconstruction loss.

    The loss is ``args.lambda_wav * MSE(x, x_hat)`` plus, for each window size
    ``s = 2**i`` with ``i`` in 6..11 (64 to 2048 samples):

    * the mean absolute difference of the two mel spectrograms, and
    * ``sqrt(s / 2)`` times the mean of the per-frame RMS difference of the
      log-mel spectrograms (RMS taken over ``dim=-2``).

    Args:
        x: reference waveform.
        x_hat: reconstructed waveform. Its device is used for the mel transforms.
        args: object providing ``lambda_wav`` and ``sampling_rate``.
        eps: floor added before ``log`` to avoid ``log(0)``.

    Returns:
        Scalar tensor.
    """
    device = x_hat.device
    # Waveform term: this is an MSE (L2) loss, not L1.
    loss = args.lambda_wav * F.mse_loss(x, x_hat)
    # Window sizes 2**6=64 .. 2**11=2048. The Encodec setting would be
    # range(5, 12), i.e. starting at 32.
    for i in range(6, 12):
        s = 2**i
        melspec = MelSpectrogram(
            sample_rate=args.sampling_rate,
            n_fft=max(s, 512),
            win_length=s,
            hop_length=s // 4,
            n_mels=64,
            wkwargs={"device": device},
        ).to(device)  # wkwargs only places the window; .to() moves the module's buffers
        S_x = melspec(x)
        S_x_hat = melspec(x_hat)
        l1_loss = (S_x - S_x_hat).abs().mean()
        log_sq_diff = (torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2
        # NOTE(review): ``** 0.5`` has an infinite derivative at 0, so a frame
        # whose log-mels match exactly would yield NaN gradients; consider an
        # eps inside the sqrt if that is ever observed.
        l2_loss = (log_sq_diff.mean(dim=-2) ** 0.5).mean()
        # Larger windows get a larger weight on the log-mel term.
        alpha = (s / 2) ** 0.5
        loss += l1_loss + alpha * l2_loss
    return loss
def criterion_d(
    y_disc_r,
    y_disc_gen,
    fmap_r_det,
    fmap_gen_det,
    y_df_hat_r,
    y_df_hat_g,
    fmap_f_r,
    fmap_f_g,
    y_ds_hat_r,
    y_ds_hat_g,
    fmap_s_r,
    fmap_s_g,
):
    """Discriminator hinge loss averaged over three discriminator families.

    For each family (STFT, MPD, MSD), every sub-discriminator contributes
    ``mean(relu(1 - real)) + mean(relu(1 + fake))``. Each family's sum is
    divided by its number of sub-discriminators. The three results are then
    averaged.

    The feature-map arguments (``fmap_*``) are accepted for signature
    compatibility but are not used.
    """

    def _hinge_sum(real_logits, fake_logits):
        # Sum of hinge terms across the sub-discriminators of one family.
        acc = 0.0
        for idx, real in enumerate(real_logits):
            acc += F.relu(1 - real).mean() + F.relu(1 + fake_logits[idx]).mean()
        return acc

    stft_sum = _hinge_sum(y_disc_r, y_disc_gen)
    mpd_sum = _hinge_sum(y_df_hat_r, y_df_hat_g)
    msd_sum = _hinge_sum(y_ds_hat_r, y_ds_hat_g)
    per_family = (
        stft_sum / len(y_disc_gen)
        + mpd_sum / len(y_df_hat_r)
        + msd_sum / len(y_ds_hat_r)
    )
    return per_family / 3.0
def criterion_g(
    commit_loss,
    x,
    G_x,
    fmap_r,
    fmap_gen,
    y_disc_r,
    y_disc_gen,
    y_df_hat_r,
    y_df_hat_g,
    fmap_f_r,
    fmap_f_g,
    y_ds_hat_r,
    y_ds_hat_g,
    fmap_s_r,
    fmap_s_g,
    args,
):
    """Weighted generator objective.

    Combines four terms:

    * the commitment loss;
    * the hinge adversarial loss on ``y_disc_gen``;
    * ``feature_loss + sim_loss`` summed over the STFT, MPD and MSD
      discriminators and divided by 3;
    * the reconstruction loss between ``x`` and ``G_x``.

    They are weighted by ``args.lambda_com``, ``args.lambda_adv``,
    ``args.lambda_feat`` and ``args.lambda_rec`` respectively.

    Returns:
        Tuple ``(total_loss, adv_g_loss, feat_loss, rec_loss)``.
    """
    # NOTE(review): only the STFT discriminator's fake logits feed the
    # adversarial term, while the feature term covers all three
    # discriminators (loss_g averages all three) -- confirm this is intended.
    adv_g_loss = adversarial_g_loss(y_disc_gen)

    groups = (
        (fmap_r, fmap_gen, y_disc_r, y_disc_gen),
        (fmap_f_r, fmap_f_g, y_df_hat_r, y_df_hat_g),
        (fmap_s_r, fmap_s_g, y_ds_hat_r, y_ds_hat_g),
    )
    feat_terms = []
    for fm_real, fm_fake, logits_real, logits_fake in groups:
        feat_terms.append(feature_loss(fm_real, fm_fake))
        feat_terms.append(sim_loss(logits_real, logits_fake))
    feat_loss = sum(feat_terms) / 3.0

    rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args)

    weighted_commit = args.lambda_com * commit_loss
    weighted_adv = args.lambda_adv * adv_g_loss
    weighted_feat = args.lambda_feat * feat_loss
    weighted_rec = args.lambda_rec * rec_loss
    total_loss = weighted_commit + weighted_adv + weighted_feat + weighted_rec
    return total_loss, adv_g_loss, feat_loss, rec_loss
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Warm-up gate: return ``value`` while ``global_step < threshold``, else ``weight``."""
    return value if global_step < threshold else weight
def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
    """Return ``value`` on every step divisible by 3, otherwise ``weight``.

    ``threshold`` is accepted for signature parity with ``adopt_weight`` but
    is not used.
    """
    return value if global_step % 3 == 0 else weight
def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
    """Adversarial-loss weight derived from gradient norms at ``last_layer``.

    Computes ``||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4)`` with respect
    to ``last_layer``. The ratio is clamped to ``[1.0, 1.0]``, detached and
    scaled by ``args.lambda_adv``.

    Because both clamp bounds are 1.0, any finite ratio becomes exactly 1.0.
    The result is therefore ``args.lambda_adv`` as a detached 0-dim tensor.
    The gradient computation still runs, so it still fails if ``last_layer``
    is not part of either graph.

    Args:
        nll_loss: reconstruction-style loss tensor.
        g_loss: adversarial loss tensor.
        last_layer: tensor both losses are differentiated against.
        args: object providing ``lambda_adv``.

    Returns:
        Detached tensor weight.

    Raises:
        ValueError: if ``last_layer`` is None.
    """
    if last_layer is None:
        # Raise explicitly. The previous ``assert 1 == 2`` is stripped under
        # ``python -O``, after which the function failed with a NameError.
        raise ValueError("last_layer cannot be none")
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    # NOTE(review): min == max == 1.0 pins the ratio to 1.0, which effectively
    # disables adaptive weighting. Widen the bounds to re-enable it.
    d_weight = torch.clamp(d_weight, 1.0, 1.0).detach()
    return d_weight * args.lambda_adv
def loss_g(
    codebook_loss,
    speech,
    speech_hat,
    fmap,
    fmap_hat,
    y,
    y_hat,
    global_step,
    y_df,
    y_df_hat,
    y_ds,
    y_ds_hat,
    fmap_f,
    fmap_f_hat,
    fmap_s,
    fmap_s_hat,
    args=None,
):
    """Total generator loss with a discriminator warm-up gate.

    Args:
        codebook_loss: commitment loss.
        speech: ground-truth waveform.
        speech_hat: reconstructed waveform.
        fmap / fmap_hat: STFT-discriminator feature maps (real / fake).
        y / y_hat: STFT-discriminator logits (real / fake). ``y`` is unused.
        global_step: current training step.
        y_df / y_df_hat: MPD logits (real / fake). ``y_df`` is unused.
        y_ds / y_ds_hat: MSD logits (real / fake). ``y_ds`` is unused.
        fmap_f / fmap_f_hat: MPD feature maps (real / fake).
        fmap_s / fmap_s_hat: MSD feature maps (real / fake).
        args: object with ``lambda_adv``, ``lambda_feat``, ``lambda_com``,
            ``discriminator_iter_start`` plus what ``reconstruction_loss``
            reads. It is required despite the ``None`` default.

    Returns:
        Tuple ``(loss, rec_loss, adv_loss, feat_loss_tot, d_weight)``.
        Before ``args.discriminator_iter_start`` the adversarial and
        feature-matching terms are weighted by 0.
    """
    rec_loss = reconstruction_loss(speech.contiguous(), speech_hat.contiguous(), args)

    # Adversarial term: mean over the STFT, MPD and MSD discriminators.
    adv_stft, adv_mpd, adv_msd = (
        adversarial_g_loss(fake_logits) for fake_logits in (y_hat, y_df_hat, y_ds_hat)
    )
    adv_loss = (adv_stft + adv_mpd + adv_msd) / 3.0

    # Feature-matching term: mean over the same three discriminators.
    fm_stft, fm_mpd, fm_msd = (
        feature_loss(real_fm, fake_fm)
        for real_fm, fake_fm in ((fmap, fmap_hat), (fmap_f, fmap_f_hat), (fmap_s, fmap_s_hat))
    )
    feat_loss_tot = (fm_stft + fm_mpd + fm_msd) / 3.0

    # Fixed unit weight; adaptive weighting is not used here.
    d_weight = torch.tensor(1.0)
    disc_factor = adopt_weight(
        args.lambda_adv, global_step, threshold=args.discriminator_iter_start
    )
    # Feature matching is switched off together with the adversarial term.
    fm_loss_wt = 0 if disc_factor == 0.0 else args.lambda_feat

    loss = (
        rec_loss
        + d_weight * disc_factor * adv_loss
        + fm_loss_wt * feat_loss_tot
        + args.lambda_com * codebook_loss
    )
    return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
def loss_dis(
    y,
    y_hat,
    fmap,
    fmap_hat,
    y_df,
    y_df_hat,
    fmap_f,
    fmap_f_hat,
    y_ds,
    y_ds_hat,
    fmap_s,
    fmap_s_hat,
    global_step,
    args,
):
    """Gated discriminator loss.

    Returns ``criterion_d(...)`` scaled by ``args.lambda_adv``. Before
    ``args.discriminator_iter_start`` the scale is 0.0. The hinge loss is
    still evaluated in that case.
    """
    gate = adopt_weight(
        args.lambda_adv, global_step, threshold=args.discriminator_iter_start
    )
    hinge = criterion_d(
        y, y_hat, fmap, fmap_hat,
        y_df, y_df_hat, fmap_f, fmap_f_hat,
        y_ds, y_ds_hat, fmap_s, fmap_s_hat,
    )
    return gate * hinge