import torch
import torch.nn.functional as F
from torchaudio.transforms import MelSpectrogram


def adversarial_g_loss(y_disc_gen):
    """Hinge loss for the generator, averaged over discriminator outputs."""
    loss = 0.0
    for i in range(len(y_disc_gen)):
        stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze()
        loss += stft_loss
    return loss / len(y_disc_gen)


def feature_loss(fmap_r, fmap_gen):
    """Relative L1 feature-matching loss between real and generated feature maps."""
    loss = 0.0
    for i in range(len(fmap_r)):
        for j in range(len(fmap_r[i])):
            stft_loss = (
                (fmap_r[i][j] - fmap_gen[i][j]).abs() / (fmap_r[i][j].abs().mean())
            ).mean()
            loss += stft_loss
    return loss / (len(fmap_r) * len(fmap_r[0]))


def sim_loss(y_disc_r, y_disc_gen):
    """MSE between real and generated discriminator logits."""
    loss = 0.0
    for i in range(len(y_disc_r)):
        loss += F.mse_loss(y_disc_r[i], y_disc_gen[i])
    return loss / len(y_disc_r)


def reconstruction_loss(x, x_hat, args, eps=1e-7):
    """Waveform MSE plus a multi-resolution mel-spectrogram loss."""
    # NOTE (lsx): hard-coded now
    L = args.lambda_wav * F.mse_loss(x, x_hat)  # wav-domain MSE loss
    # loss_sisnr = sisnr_loss(G_x, x)
    # L += 0.01 * loss_sisnr
    # 2^6=64 -> 2^10=1024
    # NOTE (lsx): add 2^11
    for i in range(6, 12):
        # for i in range(5, 12):  # Encodec setting
        s = 2**i
        melspec = MelSpectrogram(
            sample_rate=args.sampling_rate,
            n_fft=max(s, 512),
            win_length=s,
            hop_length=s // 4,
            n_mels=64,
            wkwargs={"device": x_hat.device},
        ).to(x_hat.device)
        S_x = melspec(x)
        S_x_hat = melspec(x_hat)
        l1_loss = (S_x - S_x_hat).abs().mean()
        l2_loss = (
            ((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean(
                dim=-2
            )
            ** 0.5
        ).mean()
        alpha = (s / 2) ** 0.5
        L += l1_loss + alpha * l2_loss
    return L


def criterion_d(
    y_disc_r,
    y_disc_gen,
    fmap_r_det,
    fmap_gen_det,
    y_df_hat_r,
    y_df_hat_g,
    fmap_f_r,
    fmap_f_g,
    y_ds_hat_r,
    y_ds_hat_g,
    fmap_s_r,
    fmap_s_g,
):
    """Hinge loss for the STFT, MPD and MSD discriminators (feature maps are unused here)."""
    loss = 0.0
    loss1 = 0.0
    loss2 = 0.0
    loss3 = 0.0
    for i in range(len(y_disc_r)):
        loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[i]).mean()
    for i in range(len(y_df_hat_r)):
        loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[i]).mean()
    for i in range(len(y_ds_hat_r)):
        loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[i]).mean()
    loss = (
        loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 / len(y_ds_hat_r)
    ) / 3.0
    return loss


def criterion_g(
    commit_loss,
    x,
    G_x,
    fmap_r,
    fmap_gen,
    y_disc_r,
    y_disc_gen,
    y_df_hat_r,
    y_df_hat_g,
    fmap_f_r,
    fmap_f_g,
    y_ds_hat_r,
    y_ds_hat_g,
    fmap_s_r,
    fmap_s_g,
    args,
):
    adv_g_loss = adversarial_g_loss(y_disc_gen)
    feat_loss = (
        feature_loss(fmap_r, fmap_gen)
        + sim_loss(y_disc_r, y_disc_gen)
        + feature_loss(fmap_f_r, fmap_f_g)
        + sim_loss(y_df_hat_r, y_df_hat_g)
        + feature_loss(fmap_s_r, fmap_s_g)
        + sim_loss(y_ds_hat_r, y_ds_hat_g)
    ) / 3.0
    rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args)
    total_loss = (
        args.lambda_com * commit_loss
        + args.lambda_adv * adv_g_loss
        + args.lambda_feat * feat_loss
        + args.lambda_rec * rec_loss
    )
    return total_loss, adv_g_loss, feat_loss, rec_loss


def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return `value` until `global_step` reaches `threshold`, then `weight`."""
    if global_step < threshold:
        weight = value
    return weight


def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
    """Return `value` when `global_step` is a multiple of 3, else `weight`."""
    if global_step % 3 == 0:
        weight = value
    return weight


def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
    """VQGAN-style adaptive weight from gradient norms at the last generator layer."""
    if last_layer is None:
        raise ValueError("last_layer cannot be None")
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    # clamp(1.0, 1.0) pins the adaptive weight to 1.0 as written
    d_weight = torch.clamp(d_weight, 1.0, 1.0).detach()
    d_weight = d_weight * args.lambda_adv
    return d_weight


def loss_g(
    codebook_loss,
    speech,
    speech_hat,
    fmap,
    fmap_hat,
    y,
    y_hat,
    global_step,
    y_df,
    y_df_hat,
    y_ds,
    y_ds_hat,
    fmap_f,
    fmap_f_hat,
    fmap_s,
    fmap_s_hat,
    args=None,
):
    """
    Args:
        codebook_loss: commit loss.
        speech: ground-truth wav.
        speech_hat: reconstructed wav.
        fmap: real stft-D feature map.
        fmap_hat: fake stft-D feature map.
        y: real stft-D logits.
        y_hat: fake stft-D logits.
        global_step: global training step.
        y_df: real MPD logits.
        y_df_hat: fake MPD logits.
        y_ds: real MSD logits.
        y_ds_hat: fake MSD logits.
        fmap_f: real MPD feature map.
        fmap_f_hat: fake MPD feature map.
        fmap_s: real MSD feature map.
        fmap_s_hat: fake MSD feature map.
    """
    rec_loss = reconstruction_loss(speech.contiguous(), speech_hat.contiguous(), args)
    adv_g_loss = adversarial_g_loss(y_hat)
    adv_mpd_loss = adversarial_g_loss(y_df_hat)
    adv_msd_loss = adversarial_g_loss(y_ds_hat)
    adv_loss = (
        adv_g_loss + adv_mpd_loss + adv_msd_loss
    ) / 3.0  # NOTE(lsx): need to divide by 3?
    feat_loss = feature_loss(
        fmap, fmap_hat
    )  # + sim_loss(y_disc_r, y_disc_gen)  # NOTE(lsx): need logits?
    feat_loss_mpd = feature_loss(
        fmap_f, fmap_f_hat
    )  # + sim_loss(y_df_hat_r, y_df_hat_g)
    feat_loss_msd = feature_loss(
        fmap_s, fmap_s_hat
    )  # + sim_loss(y_ds_hat_r, y_ds_hat_g)
    feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
    d_weight = torch.tensor(1.0)
    disc_factor = adopt_weight(
        args.lambda_adv, global_step, threshold=args.discriminator_iter_start
    )
    if disc_factor == 0.0:
        fm_loss_wt = 0
    else:
        fm_loss_wt = args.lambda_feat
    loss = (
        rec_loss
        + d_weight * disc_factor * adv_loss
        + fm_loss_wt * feat_loss_tot
        + args.lambda_com * codebook_loss
    )
    return loss, rec_loss, adv_loss, feat_loss_tot, d_weight


def loss_dis(
    y,
    y_hat,
    fmap,
    fmap_hat,
    y_df,
    y_df_hat,
    fmap_f,
    fmap_f_hat,
    y_ds,
    y_ds_hat,
    fmap_s,
    fmap_s_hat,
    global_step,
    args,
):
    """Discriminator hinge loss, gated by `discriminator_iter_start`."""
    disc_factor = adopt_weight(
        args.lambda_adv, global_step, threshold=args.discriminator_iter_start
    )
    d_loss = disc_factor * criterion_d(
        y,
        y_hat,
        fmap,
        fmap_hat,
        y_df,
        y_df_hat,
        fmap_f,
        fmap_f_hat,
        y_ds,
        y_ds_hat,
        fmap_s,
        fmap_s_hat,
    )
    return d_loss
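

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training recipe).
# It assumes an `args` namespace carrying the hyperparameters referenced above
# (lambda_wav, sampling_rate, ...); the values below are placeholders rather
# than the weights used by the original training configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        sampling_rate=16000,
        lambda_wav=100.0,  # placeholder weights; tune for a real setup
        lambda_adv=1.0,
        lambda_feat=1.0,
        lambda_com=1.0,
        lambda_rec=1.0,
        discriminator_iter_start=0,
    )
    # Two one-second mono waveforms standing in for ground truth and reconstruction.
    speech = torch.randn(2, 16000)
    speech_hat = speech + 0.01 * torch.randn_like(speech)
    rec = reconstruction_loss(speech, speech_hat, args)
    print(f"reconstruction loss: {rec.item():.4f}")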