From 0150961a335e877d502a4d3bb51489484222ef4d Mon Sep 17 00:00:00 2001
From: JinZr
Date: Fri, 6 Sep 2024 21:20:45 +0800
Subject: [PATCH] minor fixes

---
 egs/libritts/CODEC/encodec/encodec.py |  1 -
 egs/libritts/CODEC/encodec/loss.py    | 30 ------------------------------
 egs/libritts/CODEC/encodec/train.py   |  7 ++++---
 3 files changed, 4 insertions(+), 34 deletions(-)

diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py
index 32d80eb38..bde03034f 100644
--- a/egs/libritts/CODEC/encodec/encodec.py
+++ b/egs/libritts/CODEC/encodec/encodec.py
@@ -146,7 +146,6 @@ class Encodec(nn.Module):
         # reset cache
         if reuse_cache or not self.training:
             self._cache = None
-
         return loss, stats
 
     def _forward_discriminator(
diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py
index 0614abf92..96300e9d6 100644
--- a/egs/libritts/CODEC/encodec/loss.py
+++ b/egs/libritts/CODEC/encodec/loss.py
@@ -30,35 +30,6 @@ def sim_loss(y_disc_r, y_disc_gen):
     return loss / len(y_disc_r)
 
 
-# def sisnr_loss(x, s, eps=1e-8):
-#     """
-#     calculate training loss
-#     input:
-#           x: separated signal, N x S tensor, estimate value
-#           s: reference signal, N x S tensor, True value
-#     Return:
-#           sisnr: N tensor
-#     """
-#     if x.shape != s.shape:
-#         if x.shape[-1] > s.shape[-1]:
-#             x = x[:, :s.shape[-1]]
-#         else:
-#             s = s[:, :x.shape[-1]]
-#     def l2norm(mat, keepdim=False):
-#         return torch.norm(mat, dim=-1, keepdim=keepdim)
-#     if x.shape != s.shape:
-#         raise RuntimeError(
-#             "Dimention mismatch when calculate si-snr, {} vs {}".format(
-#                 x.shape, s.shape))
-#     x_zm = x - torch.mean(x, dim=-1, keepdim=True)
-#     s_zm = s - torch.mean(s, dim=-1, keepdim=True)
-#     t = torch.sum(
-#         x_zm * s_zm, dim=-1,
-#         keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
-#     loss = -20. * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
-#     return torch.sum(loss) / x.shape[0]
-
-
 def reconstruction_loss(x, x_hat, args, eps=1e-7):
     # NOTE (lsx): hard-coded now
     L = args.lambda_wav * F.mse_loss(x, x_hat)  # wav L1 loss
@@ -169,7 +140,6 @@ def adopt_weight(weight, global_step, threshold=0, value=0.0):
 
 
 def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
-    # 0,3,6,9,13....这些时间步,不更新dis
     if global_step % 3 == 0:
         weight = value
     return weight
diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py
index bc39da877..e207b12f7 100644
--- a/egs/libritts/CODEC/encodec/train.py
+++ b/egs/libritts/CODEC/encodec/train.py
@@ -559,7 +559,7 @@ def train_one_epoch(
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
             )
-            if tb_writer is not None and rank == 0 and speech_hat is not None:
+            if tb_writer is not None and rank == 0:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
@@ -634,11 +634,12 @@ def compute_validation_loss(
                 speech_lengths=audio_lens,
                 global_step=params.batch_idx_train,
                 forward_generator=True,
-                return_sample=False,
+                return_sample=True,
             )
             assert loss_g.requires_grad is False
             for k, v in stats_g.items():
-                loss_info[k] = v * batch_size
+                if "returned_sample" not in k:
+                    loss_info[k] = v * batch_size
 
             # summary stats
             tot_loss = tot_loss + loss_info
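
Below is a minimal, self-contained sketch of the stats filtering that compute_validation_loss performs after this patch. It is illustrative only: plain dicts stand in for icefall's MetricsTracker, and the metric names, tensor shapes, and batch size are made up. The point is that with return_sample=True the generator's stats include a "returned_sample" entry (presumably the reconstructed and reference waveforms), which is not a scalar metric and therefore must be skipped before the per-batch scaling.

import torch

# Hypothetical generator stats: scalar losses plus the returned audio sample.
stats_g = {
    "generator_loss": torch.tensor(2.5),
    "generator_reconstruction_loss": torch.tensor(1.1),
    "returned_sample": (torch.randn(1, 24000), torch.randn(1, 24000)),
}

batch_size = 8
loss_info = {}

for k, v in stats_g.items():
    # Waveform tensors are not per-utterance metrics; scaling them by
    # batch_size and accumulating them into the summary would be meaningless.
    if "returned_sample" not in k:
        loss_info[k] = v * batch_size

print(loss_info)  # only the scalar losses remain, scaled by batch_size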