fixed loss normalization & scaling factors

JinZr 2024-10-06 15:55:49 +08:00
parent e788bb4853
commit d83ce89fca
3 changed files with 28 additions and 25 deletions

View File

@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from loss import (
     DiscriminatorAdversarialLoss,
-    FeatureMatchLoss,
+    FeatureLoss,
     GeneratorAdversarialLoss,
     MelSpectrogramReconstructionLoss,
     WavReconstructionLoss,
@@ -60,7 +60,7 @@ class Encodec(nn.Module):
         self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss(
             average_by_discriminators=True, loss_type="hinge"
         )
-        self.feature_match_loss = FeatureMatchLoss(average_by_layers=False)
+        self.feature_match_loss = FeatureLoss(average_by_layers=False)
         self.wav_reconstruction_loss = WavReconstructionLoss()
         self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss(
             sampling_rate=self.sampling_rate

View File

@@ -57,7 +57,7 @@ class GeneratorAdversarialLoss(torch.nn.Module):
         else:
             adv_loss = self.criterion(outputs)
-        return adv_loss
+        return adv_loss / len(outputs)

    def _mse_loss(self, x):
        return F.mse_loss(x, x.new_ones(x.size()))
@@ -129,7 +129,7 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
             real_loss = self.real_criterion(outputs)
             fake_loss = self.fake_criterion(outputs_hat)
-        return real_loss, fake_loss
+        return real_loss / len(outputs), fake_loss / len(outputs)

    def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(x, x.new_ones(x.size()))
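Both returns now divide by the number of discriminator outputs, which keeps the adversarial terms on a comparable scale no matter how many sub-discriminators feed them. A minimal sketch of hinge losses with this per-output averaging; the helper names and list-of-tensors interface are illustrative, not the recipe's API:

import torch
import torch.nn.functional as F

def hinge_g_loss(disc_outputs):
    # Generator hinge loss, averaged over discriminator outputs so that
    # adding more discriminators does not inflate the loss magnitude.
    loss = sum(-o.mean() for o in disc_outputs)
    return loss / len(disc_outputs)

def hinge_d_loss(real_outputs, fake_outputs):
    # Discriminator hinge loss with the same per-output averaging.
    real = sum(F.relu(1.0 - o).mean() for o in real_outputs)
    fake = sum(F.relu(1.0 + o).mean() for o in fake_outputs)
    return real / len(real_outputs), fake / len(fake_outputs)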
@@ -144,14 +144,14 @@ class DiscriminatorAdversarialLoss(torch.nn.Module):
         return F.relu(x.new_ones(x.size()) + x).mean()


-class FeatureMatchLoss(torch.nn.Module):
-    """Feature matching loss module."""
+class FeatureLoss(torch.nn.Module):
+    """Feature loss module."""

     def __init__(
         self,
         average_by_layers: bool = True,
         average_by_discriminators: bool = True,
-        include_final_outputs: bool = False,
+        include_final_outputs: bool = True,
     ):
         """Initialize FeatureMatchLoss module.
@@ -195,14 +195,16 @@ class FeatureMatchLoss(torch.nn.Module):
                 feats_hat_ = feats_hat_[:-1]
                 feats_ = feats_[:-1]
             for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
-                feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
+                feat_match_loss_ += (
+                    (feat_hat_ - feat_).abs() / (feat_.abs().mean())
+                ).mean()
             if self.average_by_layers:
                 feat_match_loss_ /= j + 1
             feat_match_loss += feat_match_loss_
         if self.average_by_discriminators:
             feat_match_loss /= i + 1

-        return feat_match_loss
+        return feat_match_loss / (len(feats) * len(feats[0]))


 class MelSpectrogramReconstructionLoss(torch.nn.Module):
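The per-layer term changes from a plain L1 on detached features to a relative L1: the distance is scaled by the mean magnitude of the real features, matching the feature loss described in the EnCodec paper, and the final return additionally averages over discriminators and layers via len(feats) * len(feats[0]). A rough standalone equivalent, assuming a flat list of per-layer feature maps (names hypothetical):

import torch

def relative_feature_loss(feats_real, feats_fake):
    # |fake - real| scaled by the mean magnitude of the real features,
    # so layers with large activations do not dominate the total.
    loss = torch.tensor(0.0)
    for real, fake in zip(feats_real, feats_fake):
        loss = loss + ((fake - real).abs() / real.abs().mean()).mean()
    return loss / len(feats_real)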
@@ -231,7 +233,7 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
             self.wav_to_specs.append(
                 MelSpectrogram(
                     sample_rate=sampling_rate,
-                    n_fft=max(s, 512),
+                    n_fft=s,
                     win_length=s,
                     hop_length=s // 4,
                     n_mels=n_mels,
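With n_fft equal to the window size, short scales are no longer zero-padded out to 512 samples, so each scale's FFT length tracks its window directly. A sketch of how such a multi-scale mel bank can be built; the scale range 2**5 through 2**11 and n_mels=64 are assumptions borrowed from the EnCodec setup, not values read from this diff:

import torch
from torchaudio.transforms import MelSpectrogram

# Hypothetical multi-scale mel bank; window sizes and n_mels assumed.
wav_to_specs = [
    MelSpectrogram(
        sample_rate=24000,
        n_fft=s,          # n_fft == win_length after this change
        win_length=s,
        hop_length=s // 4,
        n_mels=64,
    )
    for s in (2**i for i in range(5, 12))  # 32, 64, ..., 2048
]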
@@ -266,8 +268,9 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module):
             mel_hat = wav_to_spec(x_hat.squeeze(1))
             mel = wav_to_spec(x.squeeze(1))
-            alpha = (s / 2) ** 0.5
-            mel_loss += F.l1_loss(mel_hat, mel) + alpha * F.mse_loss(mel_hat, mel)
+            mel_loss += F.l1_loss(
+                mel_hat, mel, reduce=True, reduction="mean"
+            ) + F.mse_loss(mel_hat, mel, reduce=True, reduction="mean")

        # mel_hat = self.wav_to_spec(x_hat.squeeze(1))
        # mel = self.wav_to_spec(x.squeeze(1))
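The per-scale weighting alpha = (s / 2) ** 0.5 on the MSE term is gone, so every scale now contributes an equally weighted L1 plus L2 distance between the two mel spectrograms. Note that reduce=True is a long-deprecated alias in PyTorch and reduction="mean" (the default) already expresses the same reduction; a minimal equivalent of the per-scale term:

import torch
import torch.nn.functional as F

def mel_scale_loss(mel_hat, mel):
    # Equal-weight L1 + L2 reconstruction term for one mel scale;
    # both terms reduce to a mean over all time-frequency bins.
    return F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel)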
@@ -300,7 +303,7 @@ class WavReconstructionLoss(torch.nn.Module):
             Tensor: Wav loss value.
         """
-        wav_loss = F.mse_loss(x, x_hat)
+        wav_loss = F.l1_loss(x, x_hat, reduce=True, reduction="mean")

        return wav_loss
@@ -459,7 +462,7 @@ def loss_g(
 if __name__ == "__main__":
-    la = FeatureMatchLoss(average_by_layers=False, average_by_discriminators=False)
+    la = FeatureLoss(average_by_layers=False, average_by_discriminators=False)
     aa = [torch.rand(192, 192) for _ in range(3)]
     bb = [torch.rand(192, 192) for _ in range(3)]
     print(la(bb, aa))

View File

@@ -187,11 +187,11 @@ def get_params() -> AttributeDict:
             "env_info": get_env_info(),
             "sampling_rate": 24000,
             "chunk_size": 1.0,  # in seconds
-            "lambda_adv": 1.0,  # loss scaling coefficient for adversarial loss
-            "lambda_wav": 100.0,  # loss scaling coefficient for waveform loss
-            "lambda_feat": 1.0,  # loss scaling coefficient for feat loss
+            "lambda_adv": 3.0,  # loss scaling coefficient for adversarial loss
+            "lambda_wav": 0.1,  # loss scaling coefficient for waveform loss
+            "lambda_feat": 3.0,  # loss scaling coefficient for feat loss
             "lambda_rec": 1.0,  # loss scaling coefficient for reconstruction loss
-            "lambda_com": 1000.0,  # loss scaling coefficient for commitment loss
+            "lambda_com": 1.0,  # loss scaling coefficient for commitment loss
         }
     )
@@ -502,11 +502,11 @@ def train_one_epoch(
                 )
                 reconstruction_loss = (
                     params.lambda_wav * wav_reconstruction_loss
-                    + mel_reconstruction_loss
+                    + params.lambda_rec * mel_reconstruction_loss
                 )
                 gen_loss = (
                     gen_adv_loss
-                    + params.lambda_rec * reconstruction_loss
+                    + reconstruction_loss
                     + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
                     + params.lambda_com * commit_loss
                 )
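lambda_rec now scales only the mel term inside reconstruction_loss instead of the whole reconstruction sum, and the much smaller lambda_wav = 0.1 offsets the waveform loss switching from MSE to L1 above. A runnable sketch of how the pieces combine under the new defaults; the tensor values are dummies standing in for the losses computed in the training loop:

import torch

lambda_wav, lambda_rec, lambda_feat, lambda_com = 0.1, 1.0, 3.0, 1.0

# Dummy stand-ins for the losses computed in train_one_epoch.
wav_reconstruction_loss = torch.tensor(0.5)
mel_reconstruction_loss = torch.tensor(2.0)
gen_adv_loss = torch.tensor(1.0)
feature_loss = torch.tensor(0.3)
commit_loss = torch.tensor(0.1)

reconstruction_loss = (
    lambda_wav * wav_reconstruction_loss    # waveform L1, lightly weighted
    + lambda_rec * mel_reconstruction_loss  # multi-scale mel L1 + L2
)
gen_loss = (
    gen_adv_loss
    + reconstruction_loss                   # lambda_rec applied inside, not outside
    + lambda_feat * feature_loss
    + lambda_com * commit_loss
)
print(gen_loss)  # tensor(4.0500)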
@@ -747,11 +747,12 @@ def compute_validation_loss(
             ) * g_weight
             feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss
             reconstruction_loss = (
-                params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss
+                params.lambda_wav * wav_reconstruction_loss
+                + params.lambda_rec * mel_reconstruction_loss
             )
             gen_loss = (
                 gen_adv_loss
-                + params.lambda_rec * reconstruction_loss
+                + reconstruction_loss
                 + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss
                 + params.lambda_com * commit_loss
             )
@@ -861,10 +862,9 @@ def scan_pessimistic_batches_for_oom(
                     params.batch_idx_train,
                     threshold=params.discriminator_iter_start,
                 )
-                + params.lambda_rec
-                * (
+                + (
                     params.lambda_wav * wav_reconstruction_loss
-                    + mel_reconstruction_loss
+                    + params.lambda_rec * mel_reconstruction_loss
                 )
                 + params.lambda_feat
                 * (feature_stft_loss + feature_period_loss + feature_scale_loss)