mirror of https://github.com/k2-fsa/icefall.git

commit d83ce89fca
parent e788bb4853

    fixed loss normalization & scaling factors
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from loss import (
     DiscriminatorAdversarialLoss,
-    FeatureMatchLoss,
+    FeatureLoss,
     GeneratorAdversarialLoss,
     MelSpectrogramReconstructionLoss,
     WavReconstructionLoss,
@@ -60,7 +60,7 @@ class Encodec(nn.Module):
         self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
             average_by_discriminators=True, loss_type="hinge"
         )
-        self.feature_match_loss = FeatureMatchLoss(average_by_layers=False)
+        self.feature_match_loss = FeatureLoss(average_by_layers=False)
         self.wav_reconstruction_loss = WavReconstructionLoss()
         self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
             sampling_rate=self.sampling_rate
@@ -57,7 +57,7 @@ class GeneratorAdversarialLoss(torch.nn.Module):
         else:
             adv_loss = self.criterion(outputs)
 
-        return adv_loss
+        return adv_loss / len(outputs)
 
     def _mse_loss(self, x):
         return F.mse_loss(x, x.new_ones(x.size()))
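This hunk makes the generator's adversarial loss independent of how many discriminators contribute to it: `outputs` is a list with one entry per discriminator, and the accumulated loss is divided by its length. A minimal sketch of the effect, using the MSE criterion shown in `_mse_loss` (the function name and tensor shapes here are illustrative, not the repo's exact code):

    import torch
    import torch.nn.functional as F

    def generator_adv_loss(outputs):
        # outputs: list of per-discriminator logits for generated audio
        adv_loss = sum(F.mse_loss(o, o.new_ones(o.size())) for o in outputs)
        return adv_loss / len(outputs)  # average over discriminators

    print(generator_adv_loss([torch.randn(4, 100) for _ in range(3)]))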
@@ -129,7 +129,7 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
         real_loss = self.real_criterion(outputs)
         fake_loss = self.fake_criterion(outputs_hat)
 
-        return real_loss, fake_loss
+        return real_loss / len(outputs), fake_loss / len(outputs)
 
     def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
         return F.mse_loss(x, x.new_ones(x.size()))
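Analogously, the discriminator side now divides both the real and the fake loss by the number of discriminator outputs. A hedged sketch of that normalization, with MSE criteria assumed on both sides (the diff only shows `_mse_real_loss`; the zero-target fake criterion is an assumption for illustration):

    import torch
    import torch.nn.functional as F

    def discriminator_adv_loss(real_outputs, fake_outputs):
        # one logits tensor per discriminator for real and generated audio
        real_loss = sum(F.mse_loss(o, o.new_ones(o.size())) for o in real_outputs)
        fake_loss = sum(F.mse_loss(o, o.new_zeros(o.size())) for o in fake_outputs)
        # normalize both terms by the discriminator count, as in the hunk above
        return real_loss / len(real_outputs), fake_loss / len(fake_outputs)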
@@ -144,14 +144,14 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
         return F.relu(x.new_ones(x.size()) + x).mean()
 
 
-class FeatureMatchLoss(torch.nn.Module):
-    """Feature matching loss module."""
+class FeatureLoss(torch.nn.Module):
+    """Feature loss module."""
 
     def __init__(
         self,
         average_by_layers: bool = True,
         average_by_discriminators: bool = True,
-        include_final_outputs: bool = False,
+        include_final_outputs: bool = True,
     ):
         """Initialize FeatureMatchLoss module.
 
@@ -195,14 +195,16 @@ class FeatureMatchLoss(torch.nn.Module):
                 feats_hat_ = feats_hat_[:-1]
                 feats_ = feats_[:-1]
             for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
-                feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
+                feat_match_loss_ += (
+                    (feat_hat_ - feat_).abs() / (feat_.abs().mean())
+                ).mean()
             if self.average_by_layers:
                 feat_match_loss_ /= j + 1
             feat_match_loss += feat_match_loss_
         if self.average_by_discriminators:
             feat_match_loss /= i + 1
 
-        return feat_match_loss
+        return feat_match_loss / (len(feats) * len(feats[0]))
 
 
 class MelSpectrogramReconstructionLoss(torch.nn.Module):
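The replaced term changes the feature-matching distance from a plain L1 (`F.l1_loss(feat_hat_, feat_.detach())`) to a relative L1: each layer's error is scaled by the mean magnitude of the reference feature, and the final sum is normalized by the total number of (discriminator, layer) pairs. Note also that the new expression no longer detaches the reference features. A self-contained sketch of that computation (the nested-list layout is inferred from the loop structure; this is not the exact class):

    import torch

    def feature_loss(feats_hat, feats):
        # one list of intermediate-layer features per discriminator
        total = 0.0
        for fh_layers, f_layers in zip(feats_hat, feats):
            for fh, f in zip(fh_layers, f_layers):
                # L1 error scaled by the mean magnitude of the reference feature
                total = total + ((fh - f).abs() / f.abs().mean()).mean()
        # assumes every discriminator exposes the same number of layers
        return total / (len(feats) * len(feats[0]))

    feats = [[torch.rand(4, 16) for _ in range(3)] for _ in range(2)]
    feats_hat = [[torch.rand(4, 16) for _ in range(3)] for _ in range(2)]
    print(feature_loss(feats_hat, feats))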
@@ -231,7 +233,7 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
             self.wav_to_specs.append(
                 MelSpectrogram(
                     sample_rate=sampling_rate,
-                    n_fft=max(s, 512),
+                    n_fft=s,
                     win_length=s,
                     hop_length=s // 4,
                     n_mels=n_mels,
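With `n_fft=s`, the FFT size now tracks the window length at every scale instead of being floored at 512. A sketch of how the bank of transforms might be built after this change (`MelSpectrogram` is torchaudio's; the scale list and `n_mels=32` are assumptions for illustration, chosen so that `n_mels` does not exceed `n_fft // 2 + 1` at the smallest scale now that the floor is gone):

    from torchaudio.transforms import MelSpectrogram

    wav_to_specs = [
        MelSpectrogram(
            sample_rate=24000,
            n_fft=s,           # was max(s, 512)
            win_length=s,
            hop_length=s // 4,
            n_mels=32,
        )
        for s in (64, 128, 256, 512, 1024, 2048)
    ]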
@@ -266,8 +268,9 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
             mel_hat = wav_to_spec(x_hat.squeeze(1))
             mel = wav_to_spec(x.squeeze(1))
 
-            alpha = (s / 2) ** 0.5
-            mel_loss += F.l1_loss(mel_hat, mel) + alpha * F.mse_loss(mel_hat, mel)
+            mel_loss += F.l1_loss(
+                mel_hat, mel, reduce=True, reduction="mean"
+            ) + F.mse_loss(mel_hat, mel, reduce=True, reduction="mean")
 
         # mel_hat = self.wav_to_spec(x_hat.squeeze(1))
         # mel = self.wav_to_spec(x.squeeze(1))
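Functionally, dropping the per-scale weight `alpha = (s / 2) ** 0.5` means the L1 and L2 terms now contribute equally at every scale. A small portability note on the rewritten term: `reduce` has been deprecated in favor of `reduction` since PyTorch 0.4.1, and `reduction="mean"` is already the default, so on current PyTorch the same computation can be written as:

    mel_loss += F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel)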
@@ -300,7 +303,7 @@ class WavReconstructionLoss(torch.nn.Module):
             Tensor: Wav loss value.
 
         """
-        wav_loss = F.mse_loss(x, x_hat)
+        wav_loss = F.l1_loss(x, x_hat, reduce=True, reduction="mean")
 
         return wav_loss
 
@@ -459,7 +462,7 @@ def loss_g(
 
 
 if __name__ == "__main__":
-    la = FeatureMatchLoss(average_by_layers=False, average_by_discriminators=False)
+    la = FeatureLoss(average_by_layers=False, average_by_discriminators=False)
     aa = [torch.rand(192, 192) for _ in range(3)]
     bb = [torch.rand(192, 192) for _ in range(3)]
     print(la(bb, aa))
@@ -187,11 +187,11 @@ def get_params() -> AttributeDict:
             "env_info": get_env_info(),
             "sampling_rate": 24000,
             "chunk_size": 1.0,  # in seconds
-            "lambda_adv": 1.0,  # loss scaling coefficient for adversarial loss
-            "lambda_wav": 100.0,  # loss scaling coefficient for waveform loss
-            "lambda_feat": 1.0,  # loss scaling coefficient for feat loss
+            "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
+            "lambda_wav": 0.1,  # loss scaling coefficient for waveform loss
+            "lambda_feat": 3.0,  # loss scaling coefficient for feat loss
             "lambda_rec": 1.0,  # loss scaling coefficient for reconstruction loss
-            "lambda_com": 1000.0,  # loss scaling coefficient for commitment loss
+            "lambda_com": 1.0,  # loss scaling coefficient for commitment loss
         }
     )
 
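These coefficient changes pair with the normalization changes in loss.py: once the adversarial and feature losses are averaged over discriminators and layers, much smaller external scales suffice, and `lambda_rec` is moved so it multiplies only the mel term (see the train/validation hunks below). A sketch of the resulting generator objective with the new defaults (the stand-in loss values are placeholders; the `g_weight` warm-up gate is omitted, and `lambda_adv` is applied where `gen_adv_loss` is computed, which these hunks do not show, so it is left out here):

    import torch

    # stand-in values; in training these come from the loss modules above
    gen_adv_loss = torch.tensor(1.0)
    wav_reconstruction_loss = torch.tensor(0.5)
    mel_reconstruction_loss = torch.tensor(2.0)
    feature_loss = torch.tensor(0.3)
    commit_loss = torch.tensor(0.1)

    lambda_wav, lambda_rec, lambda_feat, lambda_com = 0.1, 1.0, 3.0, 1.0

    reconstruction_loss = (
        lambda_wav * wav_reconstruction_loss
        + lambda_rec * mel_reconstruction_loss  # lambda_rec now scales only mel
    )
    gen_loss = (
        gen_adv_loss
        + reconstruction_loss
        + lambda_feat * feature_loss
        + lambda_com * commit_loss
    )
    print(gen_loss)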
@@ -502,11 +502,11 @@ def train_one_epoch(
         )
         reconstruction_loss = (
             params.lambda_wav * wav_reconstruction_loss
-            + mel_reconstruction_loss
+            + params.lambda_rec * mel_reconstruction_loss
         )
         gen_loss = (
             gen_adv_loss
-            + params.lambda_rec * reconstruction_loss
+            + reconstruction_loss
             + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
             + params.lambda_com * commit_loss
         )
@@ -747,11 +747,12 @@ def compute_validation_loss(
         ) * g_weight
         feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss
         reconstruction_loss = (
-            params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss
+            params.lambda_wav * wav_reconstruction_loss
+            + params.lambda_rec * mel_reconstruction_loss
         )
         gen_loss = (
             gen_adv_loss
-            + params.lambda_rec * reconstruction_loss
+            + reconstruction_loss
             + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
             + params.lambda_com * commit_loss
         )
@@ -861,10 +862,9 @@ def scan_pessimistic_batches_for_oom(
                 params.batch_idx_train,
                 threshold=params.discriminator_iter_start,
             )
-            + params.lambda_rec
-            * (
+            + (
                 params.lambda_wav * wav_reconstruction_loss
-                + mel_reconstruction_loss
+                + params.lambda_rec * mel_reconstruction_loss
             )
             + params.lambda_feat
             * (feature_stft_loss + feature_period_loss + feature_scale_loss)