mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
minor updates
This commit is contained in:
parent
43267e3e29
commit
2356621059
@ -157,27 +157,7 @@ class Encodec(nn.Module):
|
||||
x=speech, x_hat=speech_hat
|
||||
)
|
||||
|
||||
# loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
|
||||
# commit_loss,
|
||||
# speech,
|
||||
# speech_hat,
|
||||
# fmap,
|
||||
# fmap_hat,
|
||||
# y,
|
||||
# y_hat,
|
||||
# y_p,
|
||||
# y_p_hat,
|
||||
# y_s,
|
||||
# y_s_hat,
|
||||
# fmap_p,
|
||||
# fmap_p_hat,
|
||||
# fmap_s,
|
||||
# fmap_s_hat,
|
||||
# args=self.params,
|
||||
# )
|
||||
|
||||
stats = dict(
|
||||
# generator_loss=loss.item(),
|
||||
generator_wav_reconstruction_loss=wav_reconstruction_loss.item(),
|
||||
generator_mel_reconstruction_loss=mel_reconstruction_loss.item(),
|
||||
generator_feature_stft_loss=feature_stft_loss.item(),
|
||||
@ -187,7 +167,6 @@ class Encodec(nn.Module):
|
||||
generator_period_adv_loss=gen_period_adv_loss.item(),
|
||||
generator_scale_adv_loss=gen_scale_adv_loss.item(),
|
||||
generator_commit_loss=commit_loss.item(),
|
||||
# d_weight=d_weight.item(),
|
||||
)
|
||||
|
||||
if return_sample:
|
||||
@ -260,18 +239,16 @@ class Encodec(nn.Module):
|
||||
speech_hat.contiguous().detach()
|
||||
)
|
||||
|
||||
disc_period_real_adv_loss, disc_period_fake_adv_loss = torch.tensor(
|
||||
0.0
|
||||
), torch.tensor(0.0)
|
||||
disc_period_real_adv_loss = torch.tensor(0.0)
|
||||
disc_period_fake_adv_loss = torch.tensor(0.0)
|
||||
if self.multi_period_discriminator is not None:
|
||||
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
|
||||
speech.contiguous(),
|
||||
speech_hat.contiguous().detach(),
|
||||
)
|
||||
|
||||
disc_scale_real_adv_loss, disc_scale_fake_adv_loss = torch.tensor(
|
||||
0.0
|
||||
), torch.tensor(0.0)
|
||||
disc_scale_real_adv_loss = torch.tensor(0.0)
|
||||
disc_scale_fake_adv_loss = torch.tensor(0.0)
|
||||
if self.multi_scale_discriminator is not None:
|
||||
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
|
||||
speech.contiguous(),
|
||||
|
@ -317,171 +317,3 @@ class WavReconstructionLoss(torch.nn.Module):
|
||||
wav_loss = F.l1_loss(x, x_hat)
|
||||
|
||||
return wav_loss
|
||||
|
||||
|
||||
def adversarial_g_loss(y_disc_gen):
|
||||
"""Hinge loss"""
|
||||
loss = 0.0
|
||||
for i in range(len(y_disc_gen)):
|
||||
stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze()
|
||||
loss += stft_loss
|
||||
return loss / len(y_disc_gen)
|
||||
|
||||
|
||||
def feature_loss(fmap_r, fmap_gen):
|
||||
loss = 0.0
|
||||
for i in range(len(fmap_r)):
|
||||
for j in range(len(fmap_r[i])):
|
||||
stft_loss = (
|
||||
(fmap_r[i][j] - fmap_gen[i][j]).abs() / (fmap_r[i][j].abs().mean())
|
||||
).mean()
|
||||
loss += stft_loss
|
||||
return loss / (len(fmap_r) * len(fmap_r[0]))
|
||||
|
||||
|
||||
def sim_loss(y_disc_r, y_disc_gen):
|
||||
loss = 0.0
|
||||
for i in range(len(y_disc_r)):
|
||||
loss += F.mse_loss(y_disc_r[i], y_disc_gen[i])
|
||||
return loss / len(y_disc_r)
|
||||
|
||||
|
||||
def reconstruction_loss(x, x_hat, args, eps=1e-7):
|
||||
# NOTE (lsx): hard-coded now
|
||||
L = args.lambda_wav * F.mse_loss(x, x_hat) # wav L1 loss
|
||||
# loss_sisnr = sisnr_loss(G_x, x) #
|
||||
# L += 0.01*loss_sisnr
|
||||
# 2^6=64 -> 2^10=1024
|
||||
# NOTE (lsx): add 2^11
|
||||
for i in range(6, 12):
|
||||
# for i in range(5, 12): # Encodec setting
|
||||
s = 2**i
|
||||
melspec = MelSpectrogram(
|
||||
sample_rate=args.sampling_rate,
|
||||
n_fft=max(s, 512),
|
||||
win_length=s,
|
||||
hop_length=s // 4,
|
||||
n_mels=64,
|
||||
wkwargs={"device": x_hat.device},
|
||||
).to(x_hat.device)
|
||||
S_x = melspec(x)
|
||||
S_x_hat = melspec(x_hat)
|
||||
l1_loss = (S_x - S_x_hat).abs().mean()
|
||||
l2_loss = (
|
||||
((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean(
|
||||
dim=-2
|
||||
)
|
||||
** 0.5
|
||||
).mean()
|
||||
|
||||
alpha = (s / 2) ** 0.5
|
||||
L += l1_loss + alpha * l2_loss
|
||||
return L
|
||||
|
||||
|
||||
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
||||
if global_step < threshold:
|
||||
weight = value
|
||||
return weight
|
||||
|
||||
|
||||
def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
|
||||
if last_layer is not None:
|
||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
||||
else:
|
||||
print("last_layer cannot be none")
|
||||
assert 1 == 2
|
||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
||||
d_weight = torch.clamp(d_weight, 1.0, 1.0).detach()
|
||||
d_weight = d_weight * args.lambda_adv
|
||||
return d_weight
|
||||
|
||||
|
||||
def loss_g(
|
||||
codebook_loss,
|
||||
speech,
|
||||
speech_hat,
|
||||
fmap,
|
||||
fmap_hat,
|
||||
y,
|
||||
y_hat,
|
||||
y_df,
|
||||
y_df_hat,
|
||||
y_ds,
|
||||
y_ds_hat,
|
||||
fmap_f,
|
||||
fmap_f_hat,
|
||||
fmap_s,
|
||||
fmap_s_hat,
|
||||
args=None,
|
||||
):
|
||||
"""
|
||||
args:
|
||||
codebook_loss: commit loss.
|
||||
speech: ground-truth wav.
|
||||
speech_hat: reconstructed wav.
|
||||
fmap: real stft-D feature map.
|
||||
fmap_hat: fake stft-D feature map.
|
||||
y: real stft-D logits.
|
||||
y_hat: fake stft-D logits.
|
||||
global_step: global training step.
|
||||
y_df: real MPD logits.
|
||||
y_df_hat: fake MPD logits.
|
||||
y_ds: real MSD logits.
|
||||
y_ds_hat: fake MSD logits.
|
||||
fmap_f: real MPD feature map.
|
||||
fmap_f_hat: fake MPD feature map.
|
||||
fmap_s: real MSD feature map.
|
||||
fmap_s_hat: fake MSD feature map.
|
||||
"""
|
||||
rec_loss = reconstruction_loss(speech.contiguous(), speech_hat.contiguous(), args)
|
||||
adv_g_loss = adversarial_g_loss(y_hat)
|
||||
adv_mpd_loss = adversarial_g_loss(y_df_hat)
|
||||
adv_msd_loss = adversarial_g_loss(y_ds_hat)
|
||||
adv_loss = (
|
||||
adv_g_loss + adv_mpd_loss + adv_msd_loss
|
||||
) / 3.0 # NOTE(lsx): need to divide by 3?
|
||||
feat_loss = feature_loss(
|
||||
fmap, fmap_hat
|
||||
) # + sim_loss(y_disc_r, y_disc_gen) # NOTE(lsx): need logits?
|
||||
feat_loss_mpd = feature_loss(
|
||||
fmap_f, fmap_f_hat
|
||||
) # + sim_loss(y_df_hat_r, y_df_hat_g)
|
||||
feat_loss_msd = feature_loss(
|
||||
fmap_s, fmap_s_hat
|
||||
) # + sim_loss(y_ds_hat_r, y_ds_hat_g)
|
||||
feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
|
||||
d_weight = torch.tensor(1.0)
|
||||
|
||||
# disc_factor = adopt_weight(
|
||||
# args.lambda_adv, global_step, threshold=args.discriminator_iter_start
|
||||
# )
|
||||
disc_factor = 1
|
||||
if disc_factor == 0.0:
|
||||
fm_loss_wt = 0
|
||||
else:
|
||||
fm_loss_wt = args.lambda_feat
|
||||
|
||||
loss = (
|
||||
rec_loss
|
||||
+ d_weight * disc_factor * adv_loss
|
||||
+ fm_loss_wt * feat_loss_tot
|
||||
+ args.lambda_com * codebook_loss
|
||||
)
|
||||
return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# la = FeatureLoss(average_by_layers=True, average_by_discriminators=True)
|
||||
# aa = [torch.rand(192, 192) for _ in range(3)]
|
||||
# bb = [torch.rand(192, 192) for _ in range(3)]
|
||||
# print(la(bb, aa))
|
||||
# print(feature_loss(aa, bb))
|
||||
la = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")
|
||||
aa = torch.Tensor([0.1, 0.2, 0.3, 0.4])
|
||||
bb = torch.Tensor([0.4, 0.3, 0.2, 0.1])
|
||||
print(la(aa))
|
||||
print(adversarial_g_loss(aa))
|
||||
print(la(bb))
|
||||
print(adversarial_g_loss(bb))
|
||||
|
@ -14,7 +14,6 @@ import torch.nn as nn
|
||||
from codec_datamodule import LibriTTSCodecDataModule
|
||||
from encodec import Encodec
|
||||
from lhotse.utils import fix_random_seed
|
||||
from loss import adopt_weight
|
||||
from scheduler import WarmupCosineLrScheduler
|
||||
from torch import nn
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
@ -189,10 +188,10 @@ def get_params() -> AttributeDict:
|
||||
"audio_normalization": False,
|
||||
"chunk_size": 1.0, # in seconds
|
||||
"lambda_adv": 3.0, # loss scaling coefficient for adversarial loss
|
||||
"lambda_wav": 1.0, # loss scaling coefficient for waveform loss
|
||||
"lambda_feat": 3.0, # loss scaling coefficient for feat loss
|
||||
"lambda_wav": 0.1, # loss scaling coefficient for waveform loss
|
||||
"lambda_feat": 4.0, # loss scaling coefficient for feat loss
|
||||
"lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss
|
||||
"lambda_com": 100.0, # loss scaling coefficient for commitment loss
|
||||
"lambda_com": 1.0, # loss scaling coefficient for commitment loss
|
||||
}
|
||||
)
|
||||
|
||||
@ -361,6 +360,12 @@ def prepare_input(
|
||||
return audio, audio_lens, features, features_lens
|
||||
|
||||
|
||||
def train_discriminator(weight, global_step, threshold=0, value=0.0):
|
||||
if global_step < threshold:
|
||||
weight = value
|
||||
return weight
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
@ -447,7 +452,7 @@ def train_one_epoch(
|
||||
|
||||
try:
|
||||
with autocast(enabled=params.use_fp16):
|
||||
d_weight = adopt_weight(
|
||||
d_weight = train_discriminator(
|
||||
params.lambda_adv,
|
||||
params.cur_epoch,
|
||||
threshold=params.discriminator_epoch_start,
|
||||
@ -483,7 +488,7 @@ def train_one_epoch(
|
||||
scaler.step(optimizer_d)
|
||||
|
||||
with autocast(enabled=params.use_fp16):
|
||||
g_weight = adopt_weight(
|
||||
g_weight = train_discriminator(
|
||||
params.lambda_adv,
|
||||
params.cur_epoch,
|
||||
threshold=params.discriminator_epoch_start,
|
||||
@ -702,7 +707,7 @@ def compute_validation_loss(
|
||||
loss_info = MetricsTracker()
|
||||
loss_info["samples"] = batch_size
|
||||
|
||||
d_weight = adopt_weight(
|
||||
d_weight = train_discriminator(
|
||||
params.lambda_adv,
|
||||
params.cur_epoch,
|
||||
threshold=params.discriminator_epoch_start,
|
||||
@ -735,7 +740,7 @@ def compute_validation_loss(
|
||||
for k, v in stats_d.items():
|
||||
loss_info[k] = v * batch_size
|
||||
|
||||
g_weight = adopt_weight(
|
||||
g_weight = train_discriminator(
|
||||
params.lambda_adv,
|
||||
params.cur_epoch,
|
||||
threshold=params.discriminator_epoch_start,
|
||||
@ -845,7 +850,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
+ disc_period_fake_adv_loss
|
||||
+ disc_scale_real_adv_loss
|
||||
+ disc_scale_fake_adv_loss
|
||||
) * adopt_weight(
|
||||
) * train_discriminator(
|
||||
params.lambda_adv,
|
||||
params.cur_epoch,
|
||||
threshold=params.discriminator_train_start,
|
||||
@ -873,7 +878,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
)
|
||||
loss_g = (
|
||||
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
|
||||
* adopt_weight(
|
||||
* train_discriminator(
|
||||
params.lambda_adv,
|
||||
0,
|
||||
threshold=params.discriminator_epoch_start,
|
||||
|
Loading…
x
Reference in New Issue
Block a user