mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
fixed loss normalization & scaling factors
This commit is contained in:
parent
e788bb4853
commit
d83ce89fca
@ -6,7 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
from loss import (
|
||||
DiscriminatorAdversarialLoss,
|
||||
FeatureMatchLoss,
|
||||
FeatureLoss,
|
||||
GeneratorAdversarialLoss,
|
||||
MelSpectrogramReconstructionLoss,
|
||||
WavReconstructionLoss,
|
||||
@ -60,7 +60,7 @@ class Encodec(nn.Module):
|
||||
self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
|
||||
average_by_discriminators=True, loss_type="hinge"
|
||||
)
|
||||
self.feature_match_loss = FeatureMatchLoss(average_by_layers=False)
|
||||
self.feature_match_loss = FeatureLoss(average_by_layers=False)
|
||||
self.wav_reconstruction_loss = WavReconstructionLoss()
|
||||
self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
|
||||
sampling_rate=self.sampling_rate
|
||||
|
@ -57,7 +57,7 @@ class GeneratorAdversarialLoss(torch.nn.Module):
|
||||
else:
|
||||
adv_loss = self.criterion(outputs)
|
||||
|
||||
return adv_loss
|
||||
return adv_loss / len(outputs)
|
||||
|
||||
def _mse_loss(self, x):
|
||||
return F.mse_loss(x, x.new_ones(x.size()))
|
||||
@ -129,7 +129,7 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
|
||||
real_loss = self.real_criterion(outputs)
|
||||
fake_loss = self.fake_criterion(outputs_hat)
|
||||
|
||||
return real_loss, fake_loss
|
||||
return real_loss / len(outputs), fake_loss / len(outputs)
|
||||
|
||||
def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.mse_loss(x, x.new_ones(x.size()))
|
||||
@ -144,14 +144,14 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
|
||||
return F.relu(x.new_ones(x.size()) + x).mean()
|
||||
|
||||
|
||||
class FeatureMatchLoss(torch.nn.Module):
|
||||
"""Feature matching loss module."""
|
||||
class FeatureLoss(torch.nn.Module):
|
||||
"""Feature loss module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
average_by_layers: bool = True,
|
||||
average_by_discriminators: bool = True,
|
||||
include_final_outputs: bool = False,
|
||||
include_final_outputs: bool = True,
|
||||
):
|
||||
"""Initialize FeatureMatchLoss module.
|
||||
|
||||
@ -195,14 +195,16 @@ class FeatureMatchLoss(torch.nn.Module):
|
||||
feats_hat_ = feats_hat_[:-1]
|
||||
feats_ = feats_[:-1]
|
||||
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
|
||||
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
|
||||
feat_match_loss_ += (
|
||||
(feat_hat_ - feat_).abs() / (feat_.abs().mean())
|
||||
).mean()
|
||||
if self.average_by_layers:
|
||||
feat_match_loss_ /= j + 1
|
||||
feat_match_loss += feat_match_loss_
|
||||
if self.average_by_discriminators:
|
||||
feat_match_loss /= i + 1
|
||||
|
||||
return feat_match_loss
|
||||
return feat_match_loss / (len(feats) * len(feats[0]))
|
||||
|
||||
|
||||
class MelSpectrogramReconstructionLoss(torch.nn.Module):
|
||||
@ -231,7 +233,7 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
|
||||
self.wav_to_specs.append(
|
||||
MelSpectrogram(
|
||||
sample_rate=sampling_rate,
|
||||
n_fft=max(s, 512),
|
||||
n_fft=s,
|
||||
win_length=s,
|
||||
hop_length=s // 4,
|
||||
n_mels=n_mels,
|
||||
@ -266,8 +268,9 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
|
||||
mel_hat = wav_to_spec(x_hat.squeeze(1))
|
||||
mel = wav_to_spec(x.squeeze(1))
|
||||
|
||||
alpha = (s / 2) ** 0.5
|
||||
mel_loss += F.l1_loss(mel_hat, mel) + alpha * F.mse_loss(mel_hat, mel)
|
||||
mel_loss += F.l1_loss(
|
||||
mel_hat, mel, reduce=True, reduction="mean"
|
||||
) + F.mse_loss(mel_hat, mel, reduce=True, reduction="mean")
|
||||
|
||||
# mel_hat = self.wav_to_spec(x_hat.squeeze(1))
|
||||
# mel = self.wav_to_spec(x.squeeze(1))
|
||||
@ -300,7 +303,7 @@ class WavReconstructionLoss(torch.nn.Module):
|
||||
Tensor: Wav loss value.
|
||||
|
||||
"""
|
||||
wav_loss = F.mse_loss(x, x_hat)
|
||||
wav_loss = F.l1_loss(x, x_hat, reduce=True, reduction="mean")
|
||||
|
||||
return wav_loss
|
||||
|
||||
@ -459,7 +462,7 @@ def loss_g(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
la = FeatureMatchLoss(average_by_layers=False, average_by_discriminators=False)
|
||||
la = FeatureLoss(average_by_layers=False, average_by_discriminators=False)
|
||||
aa = [torch.rand(192, 192) for _ in range(3)]
|
||||
bb = [torch.rand(192, 192) for _ in range(3)]
|
||||
print(la(bb, aa))
|
||||
|
@ -187,11 +187,11 @@ def get_params() -> AttributeDict:
|
||||
"env_info": get_env_info(),
|
||||
"sampling_rate": 24000,
|
||||
"chunk_size": 1.0, # in seconds
|
||||
"lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
|
||||
"lambda_wav": 100.0, # loss scaling coefficient for waveform loss
|
||||
"lambda_feat": 1.0, # loss scaling coefficient for feat loss
|
||||
"lambda_adv": 3.0, # loss scaling coefficient for adversarial loss
|
||||
"lambda_wav": 0.1, # loss scaling coefficient for waveform loss
|
||||
"lambda_feat": 3.0, # loss scaling coefficient for feat loss
|
||||
"lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss
|
||||
"lambda_com": 1000.0, # loss scaling coefficient for commitment loss
|
||||
"lambda_com": 1.0, # loss scaling coefficient for commitment loss
|
||||
}
|
||||
)
|
||||
|
||||
@ -502,11 +502,11 @@ def train_one_epoch(
|
||||
)
|
||||
reconstruction_loss = (
|
||||
params.lambda_wav * wav_reconstruction_loss
|
||||
+ mel_reconstruction_loss
|
||||
+ params.lambda_rec * mel_reconstruction_loss
|
||||
)
|
||||
gen_loss = (
|
||||
gen_adv_loss
|
||||
+ params.lambda_rec * reconstruction_loss
|
||||
+ reconstruction_loss
|
||||
+ (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
|
||||
+ params.lambda_com * commit_loss
|
||||
)
|
||||
@ -747,11 +747,12 @@ def compute_validation_loss(
|
||||
) * g_weight
|
||||
feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss
|
||||
reconstruction_loss = (
|
||||
params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss
|
||||
params.lambda_wav * wav_reconstruction_loss
|
||||
+ params.lambda_rec * mel_reconstruction_loss
|
||||
)
|
||||
gen_loss = (
|
||||
gen_adv_loss
|
||||
+ params.lambda_rec * reconstruction_loss
|
||||
+ reconstruction_loss
|
||||
+ (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
|
||||
+ params.lambda_com * commit_loss
|
||||
)
|
||||
@ -861,10 +862,9 @@ def scan_pessimistic_batches_for_oom(
|
||||
params.batch_idx_train,
|
||||
threshold=params.discriminator_iter_start,
|
||||
)
|
||||
+ params.lambda_rec
|
||||
* (
|
||||
+ (
|
||||
params.lambda_wav * wav_reconstruction_loss
|
||||
+ mel_reconstruction_loss
|
||||
+ params.lambda_rec * mel_reconstruction_loss
|
||||
)
|
||||
+ params.lambda_feat
|
||||
* (feature_stft_loss + feature_period_loss + feature_scale_loss)
|
||||
|
Loading…
x
Reference in New Issue
Block a user