diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index a2e540dcd..725ce5d01 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -6,7 +6,7 @@ import numpy as np import torch from loss import ( DiscriminatorAdversarialLoss, - FeatureMatchLoss, + FeatureLoss, GeneratorAdversarialLoss, MelSpectrogramReconstructionLoss, WavReconstructionLoss, @@ -60,7 +60,7 @@ class Encodec(nn.Module): self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss( average_by_discriminators=True, loss_type="hinge" ) - self.feature_match_loss = FeatureMatchLoss(average_by_layers=False) + self.feature_match_loss = FeatureLoss(average_by_layers=False) self.wav_reconstruction_loss = WavReconstructionLoss() self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss( sampling_rate=self.sampling_rate diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py index 7e9bf5590..f4188a313 100644 --- a/egs/libritts/CODEC/encodec/loss.py +++ b/egs/libritts/CODEC/encodec/loss.py @@ -57,7 +57,7 @@ class GeneratorAdversarialLoss(torch.nn.Module): else: adv_loss = self.criterion(outputs) - return adv_loss + return adv_loss / len(outputs) def _mse_loss(self, x): return F.mse_loss(x, x.new_ones(x.size())) @@ -129,7 +129,7 @@ class DiscriminatorAdversarialLoss(torch.nn.Module): real_loss = self.real_criterion(outputs) fake_loss = self.fake_criterion(outputs_hat) - return real_loss, fake_loss + return real_loss / len(outputs), fake_loss / len(outputs) def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: return F.mse_loss(x, x.new_ones(x.size())) @@ -144,14 +144,14 @@ class DiscriminatorAdversarialLoss(torch.nn.Module): return F.relu(x.new_ones(x.size()) + x).mean() -class FeatureMatchLoss(torch.nn.Module): - """Feature matching loss module.""" +class FeatureLoss(torch.nn.Module): + """Feature loss module.""" def __init__( self, average_by_layers: bool = True, average_by_discriminators: bool = True, - include_final_outputs: bool = False, + include_final_outputs: bool = True, ): """Initialize FeatureMatchLoss module. @@ -195,14 +195,16 @@ class FeatureMatchLoss(torch.nn.Module): feats_hat_ = feats_hat_[:-1] feats_ = feats_[:-1] for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): - feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) + feat_match_loss_ += ( + (feat_hat_ - feat_).abs() / (feat_.abs().mean()) + ).mean() if self.average_by_layers: feat_match_loss_ /= j + 1 feat_match_loss += feat_match_loss_ if self.average_by_discriminators: feat_match_loss /= i + 1 - return feat_match_loss + return feat_match_loss / (len(feats) * len(feats[0])) class MelSpectrogramReconstructionLoss(torch.nn.Module): @@ -231,7 +233,7 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module): self.wav_to_specs.append( MelSpectrogram( sample_rate=sampling_rate, - n_fft=max(s, 512), + n_fft=s, win_length=s, hop_length=s // 4, n_mels=n_mels, @@ -266,8 +268,9 @@ class MelSpectrogramReconstructionLoss(torch.nn.Module): mel_hat = wav_to_spec(x_hat.squeeze(1)) mel = wav_to_spec(x.squeeze(1)) - alpha = (s / 2) ** 0.5 - mel_loss += F.l1_loss(mel_hat, mel) + alpha * F.mse_loss(mel_hat, mel) + mel_loss += F.l1_loss( + mel_hat, mel, reduce=True, reduction="mean" + ) + F.mse_loss(mel_hat, mel, reduce=True, reduction="mean") # mel_hat = self.wav_to_spec(x_hat.squeeze(1)) # mel = self.wav_to_spec(x.squeeze(1)) @@ -300,7 +303,7 @@ class WavReconstructionLoss(torch.nn.Module): Tensor: Wav loss value. """ - wav_loss = F.mse_loss(x, x_hat) + wav_loss = F.l1_loss(x, x_hat, reduce=True, reduction="mean") return wav_loss @@ -459,7 +462,7 @@ def loss_g( if __name__ == "__main__": - la = FeatureMatchLoss(average_by_layers=False, average_by_discriminators=False) + la = FeatureLoss(average_by_layers=False, average_by_discriminators=False) aa = [torch.rand(192, 192) for _ in range(3)] bb = [torch.rand(192, 192) for _ in range(3)] print(la(bb, aa)) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 0adffb658..5b21c81dd 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -187,11 +187,11 @@ def get_params() -> AttributeDict: "env_info": get_env_info(), "sampling_rate": 24000, "chunk_size": 1.0, # in seconds - "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss - "lambda_wav": 100.0, # loss scaling coefficient for waveform loss - "lambda_feat": 1.0, # loss scaling coefficient for feat loss + "lambda_adv": 3.0, # loss scaling coefficient for adversarial loss + "lambda_wav": 0.1, # loss scaling coefficient for waveform loss + "lambda_feat": 3.0, # loss scaling coefficient for feat loss "lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss - "lambda_com": 1000.0, # loss scaling coefficient for commitment loss + "lambda_com": 1.0, # loss scaling coefficient for commitment loss } ) @@ -502,11 +502,11 @@ def train_one_epoch( ) reconstruction_loss = ( params.lambda_wav * wav_reconstruction_loss - + mel_reconstruction_loss + + params.lambda_rec * mel_reconstruction_loss ) gen_loss = ( gen_adv_loss - + params.lambda_rec * reconstruction_loss + + reconstruction_loss + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss + params.lambda_com * commit_loss ) @@ -747,11 +747,12 @@ def compute_validation_loss( ) * g_weight feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss reconstruction_loss = ( - params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss + params.lambda_wav * wav_reconstruction_loss + + params.lambda_rec * mel_reconstruction_loss ) gen_loss = ( gen_adv_loss - + params.lambda_rec * reconstruction_loss + + reconstruction_loss + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss + params.lambda_com * commit_loss ) @@ -861,10 +862,9 @@ def scan_pessimistic_batches_for_oom( params.batch_idx_train, threshold=params.discriminator_iter_start, ) - + params.lambda_rec - * ( + + ( params.lambda_wav * wav_reconstruction_loss - + mel_reconstruction_loss + + params.lambda_rec * mel_reconstruction_loss ) + params.lambda_feat * (feature_stft_loss + feature_period_loss + feature_scale_loss)