mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
minor updates
This commit is contained in:
parent
43267e3e29
commit
2356621059
@ -157,27 +157,7 @@ class Encodec(nn.Module):
|
|||||||
x=speech, x_hat=speech_hat
|
x=speech, x_hat=speech_hat
|
||||||
)
|
)
|
||||||
|
|
||||||
# loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g(
|
|
||||||
# commit_loss,
|
|
||||||
# speech,
|
|
||||||
# speech_hat,
|
|
||||||
# fmap,
|
|
||||||
# fmap_hat,
|
|
||||||
# y,
|
|
||||||
# y_hat,
|
|
||||||
# y_p,
|
|
||||||
# y_p_hat,
|
|
||||||
# y_s,
|
|
||||||
# y_s_hat,
|
|
||||||
# fmap_p,
|
|
||||||
# fmap_p_hat,
|
|
||||||
# fmap_s,
|
|
||||||
# fmap_s_hat,
|
|
||||||
# args=self.params,
|
|
||||||
# )
|
|
||||||
|
|
||||||
stats = dict(
|
stats = dict(
|
||||||
# generator_loss=loss.item(),
|
|
||||||
generator_wav_reconstruction_loss=wav_reconstruction_loss.item(),
|
generator_wav_reconstruction_loss=wav_reconstruction_loss.item(),
|
||||||
generator_mel_reconstruction_loss=mel_reconstruction_loss.item(),
|
generator_mel_reconstruction_loss=mel_reconstruction_loss.item(),
|
||||||
generator_feature_stft_loss=feature_stft_loss.item(),
|
generator_feature_stft_loss=feature_stft_loss.item(),
|
||||||
@ -187,7 +167,6 @@ class Encodec(nn.Module):
|
|||||||
generator_period_adv_loss=gen_period_adv_loss.item(),
|
generator_period_adv_loss=gen_period_adv_loss.item(),
|
||||||
generator_scale_adv_loss=gen_scale_adv_loss.item(),
|
generator_scale_adv_loss=gen_scale_adv_loss.item(),
|
||||||
generator_commit_loss=commit_loss.item(),
|
generator_commit_loss=commit_loss.item(),
|
||||||
# d_weight=d_weight.item(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if return_sample:
|
if return_sample:
|
||||||
@ -260,18 +239,16 @@ class Encodec(nn.Module):
|
|||||||
speech_hat.contiguous().detach()
|
speech_hat.contiguous().detach()
|
||||||
)
|
)
|
||||||
|
|
||||||
disc_period_real_adv_loss, disc_period_fake_adv_loss = torch.tensor(
|
disc_period_real_adv_loss = torch.tensor(0.0)
|
||||||
0.0
|
disc_period_fake_adv_loss = torch.tensor(0.0)
|
||||||
), torch.tensor(0.0)
|
|
||||||
if self.multi_period_discriminator is not None:
|
if self.multi_period_discriminator is not None:
|
||||||
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
|
y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator(
|
||||||
speech.contiguous(),
|
speech.contiguous(),
|
||||||
speech_hat.contiguous().detach(),
|
speech_hat.contiguous().detach(),
|
||||||
)
|
)
|
||||||
|
|
||||||
disc_scale_real_adv_loss, disc_scale_fake_adv_loss = torch.tensor(
|
disc_scale_real_adv_loss = torch.tensor(0.0)
|
||||||
0.0
|
disc_scale_fake_adv_loss = torch.tensor(0.0)
|
||||||
), torch.tensor(0.0)
|
|
||||||
if self.multi_scale_discriminator is not None:
|
if self.multi_scale_discriminator is not None:
|
||||||
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
|
y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator(
|
||||||
speech.contiguous(),
|
speech.contiguous(),
|
||||||
|
@ -317,171 +317,3 @@ class WavReconstructionLoss(torch.nn.Module):
|
|||||||
wav_loss = F.l1_loss(x, x_hat)
|
wav_loss = F.l1_loss(x, x_hat)
|
||||||
|
|
||||||
return wav_loss
|
return wav_loss
|
||||||
|
|
||||||
|
|
||||||
def adversarial_g_loss(y_disc_gen):
|
|
||||||
"""Hinge loss"""
|
|
||||||
loss = 0.0
|
|
||||||
for i in range(len(y_disc_gen)):
|
|
||||||
stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze()
|
|
||||||
loss += stft_loss
|
|
||||||
return loss / len(y_disc_gen)
|
|
||||||
|
|
||||||
|
|
||||||
def feature_loss(fmap_r, fmap_gen):
|
|
||||||
loss = 0.0
|
|
||||||
for i in range(len(fmap_r)):
|
|
||||||
for j in range(len(fmap_r[i])):
|
|
||||||
stft_loss = (
|
|
||||||
(fmap_r[i][j] - fmap_gen[i][j]).abs() / (fmap_r[i][j].abs().mean())
|
|
||||||
).mean()
|
|
||||||
loss += stft_loss
|
|
||||||
return loss / (len(fmap_r) * len(fmap_r[0]))
|
|
||||||
|
|
||||||
|
|
||||||
def sim_loss(y_disc_r, y_disc_gen):
|
|
||||||
loss = 0.0
|
|
||||||
for i in range(len(y_disc_r)):
|
|
||||||
loss += F.mse_loss(y_disc_r[i], y_disc_gen[i])
|
|
||||||
return loss / len(y_disc_r)
|
|
||||||
|
|
||||||
|
|
||||||
def reconstruction_loss(x, x_hat, args, eps=1e-7):
|
|
||||||
# NOTE (lsx): hard-coded now
|
|
||||||
L = args.lambda_wav * F.mse_loss(x, x_hat) # wav L1 loss
|
|
||||||
# loss_sisnr = sisnr_loss(G_x, x) #
|
|
||||||
# L += 0.01*loss_sisnr
|
|
||||||
# 2^6=64 -> 2^10=1024
|
|
||||||
# NOTE (lsx): add 2^11
|
|
||||||
for i in range(6, 12):
|
|
||||||
# for i in range(5, 12): # Encodec setting
|
|
||||||
s = 2**i
|
|
||||||
melspec = MelSpectrogram(
|
|
||||||
sample_rate=args.sampling_rate,
|
|
||||||
n_fft=max(s, 512),
|
|
||||||
win_length=s,
|
|
||||||
hop_length=s // 4,
|
|
||||||
n_mels=64,
|
|
||||||
wkwargs={"device": x_hat.device},
|
|
||||||
).to(x_hat.device)
|
|
||||||
S_x = melspec(x)
|
|
||||||
S_x_hat = melspec(x_hat)
|
|
||||||
l1_loss = (S_x - S_x_hat).abs().mean()
|
|
||||||
l2_loss = (
|
|
||||||
((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean(
|
|
||||||
dim=-2
|
|
||||||
)
|
|
||||||
** 0.5
|
|
||||||
).mean()
|
|
||||||
|
|
||||||
alpha = (s / 2) ** 0.5
|
|
||||||
L += l1_loss + alpha * l2_loss
|
|
||||||
return L
|
|
||||||
|
|
||||||
|
|
||||||
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
|
||||||
if global_step < threshold:
|
|
||||||
weight = value
|
|
||||||
return weight
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args):
|
|
||||||
if last_layer is not None:
|
|
||||||
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
|
||||||
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
|
||||||
else:
|
|
||||||
print("last_layer cannot be none")
|
|
||||||
assert 1 == 2
|
|
||||||
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
|
||||||
d_weight = torch.clamp(d_weight, 1.0, 1.0).detach()
|
|
||||||
d_weight = d_weight * args.lambda_adv
|
|
||||||
return d_weight
|
|
||||||
|
|
||||||
|
|
||||||
def loss_g(
|
|
||||||
codebook_loss,
|
|
||||||
speech,
|
|
||||||
speech_hat,
|
|
||||||
fmap,
|
|
||||||
fmap_hat,
|
|
||||||
y,
|
|
||||||
y_hat,
|
|
||||||
y_df,
|
|
||||||
y_df_hat,
|
|
||||||
y_ds,
|
|
||||||
y_ds_hat,
|
|
||||||
fmap_f,
|
|
||||||
fmap_f_hat,
|
|
||||||
fmap_s,
|
|
||||||
fmap_s_hat,
|
|
||||||
args=None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
args:
|
|
||||||
codebook_loss: commit loss.
|
|
||||||
speech: ground-truth wav.
|
|
||||||
speech_hat: reconstructed wav.
|
|
||||||
fmap: real stft-D feature map.
|
|
||||||
fmap_hat: fake stft-D feature map.
|
|
||||||
y: real stft-D logits.
|
|
||||||
y_hat: fake stft-D logits.
|
|
||||||
global_step: global training step.
|
|
||||||
y_df: real MPD logits.
|
|
||||||
y_df_hat: fake MPD logits.
|
|
||||||
y_ds: real MSD logits.
|
|
||||||
y_ds_hat: fake MSD logits.
|
|
||||||
fmap_f: real MPD feature map.
|
|
||||||
fmap_f_hat: fake MPD feature map.
|
|
||||||
fmap_s: real MSD feature map.
|
|
||||||
fmap_s_hat: fake MSD feature map.
|
|
||||||
"""
|
|
||||||
rec_loss = reconstruction_loss(speech.contiguous(), speech_hat.contiguous(), args)
|
|
||||||
adv_g_loss = adversarial_g_loss(y_hat)
|
|
||||||
adv_mpd_loss = adversarial_g_loss(y_df_hat)
|
|
||||||
adv_msd_loss = adversarial_g_loss(y_ds_hat)
|
|
||||||
adv_loss = (
|
|
||||||
adv_g_loss + adv_mpd_loss + adv_msd_loss
|
|
||||||
) / 3.0 # NOTE(lsx): need to divide by 3?
|
|
||||||
feat_loss = feature_loss(
|
|
||||||
fmap, fmap_hat
|
|
||||||
) # + sim_loss(y_disc_r, y_disc_gen) # NOTE(lsx): need logits?
|
|
||||||
feat_loss_mpd = feature_loss(
|
|
||||||
fmap_f, fmap_f_hat
|
|
||||||
) # + sim_loss(y_df_hat_r, y_df_hat_g)
|
|
||||||
feat_loss_msd = feature_loss(
|
|
||||||
fmap_s, fmap_s_hat
|
|
||||||
) # + sim_loss(y_ds_hat_r, y_ds_hat_g)
|
|
||||||
feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0
|
|
||||||
d_weight = torch.tensor(1.0)
|
|
||||||
|
|
||||||
# disc_factor = adopt_weight(
|
|
||||||
# args.lambda_adv, global_step, threshold=args.discriminator_iter_start
|
|
||||||
# )
|
|
||||||
disc_factor = 1
|
|
||||||
if disc_factor == 0.0:
|
|
||||||
fm_loss_wt = 0
|
|
||||||
else:
|
|
||||||
fm_loss_wt = args.lambda_feat
|
|
||||||
|
|
||||||
loss = (
|
|
||||||
rec_loss
|
|
||||||
+ d_weight * disc_factor * adv_loss
|
|
||||||
+ fm_loss_wt * feat_loss_tot
|
|
||||||
+ args.lambda_com * codebook_loss
|
|
||||||
)
|
|
||||||
return loss, rec_loss, adv_loss, feat_loss_tot, d_weight
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# la = FeatureLoss(average_by_layers=True, average_by_discriminators=True)
|
|
||||||
# aa = [torch.rand(192, 192) for _ in range(3)]
|
|
||||||
# bb = [torch.rand(192, 192) for _ in range(3)]
|
|
||||||
# print(la(bb, aa))
|
|
||||||
# print(feature_loss(aa, bb))
|
|
||||||
la = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge")
|
|
||||||
aa = torch.Tensor([0.1, 0.2, 0.3, 0.4])
|
|
||||||
bb = torch.Tensor([0.4, 0.3, 0.2, 0.1])
|
|
||||||
print(la(aa))
|
|
||||||
print(adversarial_g_loss(aa))
|
|
||||||
print(la(bb))
|
|
||||||
print(adversarial_g_loss(bb))
|
|
||||||
|
@ -14,7 +14,6 @@ import torch.nn as nn
|
|||||||
from codec_datamodule import LibriTTSCodecDataModule
|
from codec_datamodule import LibriTTSCodecDataModule
|
||||||
from encodec import Encodec
|
from encodec import Encodec
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from loss import adopt_weight
|
|
||||||
from scheduler import WarmupCosineLrScheduler
|
from scheduler import WarmupCosineLrScheduler
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.cuda.amp import GradScaler, autocast
|
from torch.cuda.amp import GradScaler, autocast
|
||||||
@ -189,10 +188,10 @@ def get_params() -> AttributeDict:
|
|||||||
"audio_normalization": False,
|
"audio_normalization": False,
|
||||||
"chunk_size": 1.0, # in seconds
|
"chunk_size": 1.0, # in seconds
|
||||||
"lambda_adv": 3.0, # loss scaling coefficient for adversarial loss
|
"lambda_adv": 3.0, # loss scaling coefficient for adversarial loss
|
||||||
"lambda_wav": 1.0, # loss scaling coefficient for waveform loss
|
"lambda_wav": 0.1, # loss scaling coefficient for waveform loss
|
||||||
"lambda_feat": 3.0, # loss scaling coefficient for feat loss
|
"lambda_feat": 4.0, # loss scaling coefficient for feat loss
|
||||||
"lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss
|
"lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss
|
||||||
"lambda_com": 100.0, # loss scaling coefficient for commitment loss
|
"lambda_com": 1.0, # loss scaling coefficient for commitment loss
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -361,6 +360,12 @@ def prepare_input(
|
|||||||
return audio, audio_lens, features, features_lens
|
return audio, audio_lens, features, features_lens
|
||||||
|
|
||||||
|
|
||||||
|
def train_discriminator(weight, global_step, threshold=0, value=0.0):
|
||||||
|
if global_step < threshold:
|
||||||
|
weight = value
|
||||||
|
return weight
|
||||||
|
|
||||||
|
|
||||||
def train_one_epoch(
|
def train_one_epoch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: Union[nn.Module, DDP],
|
model: Union[nn.Module, DDP],
|
||||||
@ -447,7 +452,7 @@ def train_one_epoch(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with autocast(enabled=params.use_fp16):
|
with autocast(enabled=params.use_fp16):
|
||||||
d_weight = adopt_weight(
|
d_weight = train_discriminator(
|
||||||
params.lambda_adv,
|
params.lambda_adv,
|
||||||
params.cur_epoch,
|
params.cur_epoch,
|
||||||
threshold=params.discriminator_epoch_start,
|
threshold=params.discriminator_epoch_start,
|
||||||
@ -483,7 +488,7 @@ def train_one_epoch(
|
|||||||
scaler.step(optimizer_d)
|
scaler.step(optimizer_d)
|
||||||
|
|
||||||
with autocast(enabled=params.use_fp16):
|
with autocast(enabled=params.use_fp16):
|
||||||
g_weight = adopt_weight(
|
g_weight = train_discriminator(
|
||||||
params.lambda_adv,
|
params.lambda_adv,
|
||||||
params.cur_epoch,
|
params.cur_epoch,
|
||||||
threshold=params.discriminator_epoch_start,
|
threshold=params.discriminator_epoch_start,
|
||||||
@ -702,7 +707,7 @@ def compute_validation_loss(
|
|||||||
loss_info = MetricsTracker()
|
loss_info = MetricsTracker()
|
||||||
loss_info["samples"] = batch_size
|
loss_info["samples"] = batch_size
|
||||||
|
|
||||||
d_weight = adopt_weight(
|
d_weight = train_discriminator(
|
||||||
params.lambda_adv,
|
params.lambda_adv,
|
||||||
params.cur_epoch,
|
params.cur_epoch,
|
||||||
threshold=params.discriminator_epoch_start,
|
threshold=params.discriminator_epoch_start,
|
||||||
@ -735,7 +740,7 @@ def compute_validation_loss(
|
|||||||
for k, v in stats_d.items():
|
for k, v in stats_d.items():
|
||||||
loss_info[k] = v * batch_size
|
loss_info[k] = v * batch_size
|
||||||
|
|
||||||
g_weight = adopt_weight(
|
g_weight = train_discriminator(
|
||||||
params.lambda_adv,
|
params.lambda_adv,
|
||||||
params.cur_epoch,
|
params.cur_epoch,
|
||||||
threshold=params.discriminator_epoch_start,
|
threshold=params.discriminator_epoch_start,
|
||||||
@ -845,7 +850,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
+ disc_period_fake_adv_loss
|
+ disc_period_fake_adv_loss
|
||||||
+ disc_scale_real_adv_loss
|
+ disc_scale_real_adv_loss
|
||||||
+ disc_scale_fake_adv_loss
|
+ disc_scale_fake_adv_loss
|
||||||
) * adopt_weight(
|
) * train_discriminator(
|
||||||
params.lambda_adv,
|
params.lambda_adv,
|
||||||
params.cur_epoch,
|
params.cur_epoch,
|
||||||
threshold=params.discriminator_train_start,
|
threshold=params.discriminator_train_start,
|
||||||
@ -873,7 +878,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
)
|
)
|
||||||
loss_g = (
|
loss_g = (
|
||||||
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
|
(gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss)
|
||||||
* adopt_weight(
|
* train_discriminator(
|
||||||
params.lambda_adv,
|
params.lambda_adv,
|
||||||
0,
|
0,
|
||||||
threshold=params.discriminator_epoch_start,
|
threshold=params.discriminator_epoch_start,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user