minor fixes

JinZr 2024-09-06 21:20:45 +08:00
parent 2e5055a847
commit 0150961a33
3 changed files with 4 additions and 34 deletions

View File

@@ -146,7 +146,6 @@ class Encodec(nn.Module):
         # reset cache
         if reuse_cache or not self.training:
             self._cache = None
-
         return loss, stats

     def _forward_discriminator(

View File

@@ -30,35 +30,6 @@ def sim_loss(y_disc_r, y_disc_gen):
     return loss / len(y_disc_r)


-# def sisnr_loss(x, s, eps=1e-8):
-#     """
-#     calculate training loss
-#     input:
-#           x: separated signal, N x S tensor, estimate value
-#           s: reference signal, N x S tensor, True value
-#     Return:
-#           sisnr: N tensor
-#     """
-#     if x.shape != s.shape:
-#         if x.shape[-1] > s.shape[-1]:
-#             x = x[:, :s.shape[-1]]
-#         else:
-#             s = s[:, :x.shape[-1]]
-#     def l2norm(mat, keepdim=False):
-#         return torch.norm(mat, dim=-1, keepdim=keepdim)
-#     if x.shape != s.shape:
-#         raise RuntimeError(
-#             "Dimention mismatch when calculate si-snr, {} vs {}".format(
-#                 x.shape, s.shape))
-#     x_zm = x - torch.mean(x, dim=-1, keepdim=True)
-#     s_zm = s - torch.mean(s, dim=-1, keepdim=True)
-#     t = torch.sum(
-#         x_zm * s_zm, dim=-1,
-#         keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
-#     loss = -20. * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
-#     return torch.sum(loss) / x.shape[0]
-
-
 def reconstruction_loss(x, x_hat, args, eps=1e-7):
     # NOTE (lsx): hard-coded now
     L = args.lambda_wav * F.mse_loss(x, x_hat)  # wav L1 loss
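Note: the block removed above is a commented-out scale-invariant SNR (SI-SNR) loss. For reference only, a minimal standalone sketch of the same computation, assuming (N, S) waveform tensors; this is not part of the commit and the function name is illustrative:

import torch


def sisnr_loss(x: torch.Tensor, s: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Negative SI-SNR averaged over the batch.

    x: estimated signal, (N, S); s: reference signal, (N, S).
    """
    # trim both signals to a common length if they differ
    length = min(x.shape[-1], s.shape[-1])
    x, s = x[..., :length], s[..., :length]
    # zero-mean each signal along the time axis
    x_zm = x - x.mean(dim=-1, keepdim=True)
    s_zm = s - s.mean(dim=-1, keepdim=True)
    # project the estimate onto the reference (target component)
    t = (x_zm * s_zm).sum(dim=-1, keepdim=True) * s_zm / (
        s_zm.norm(dim=-1, keepdim=True) ** 2 + eps
    )
    # SI-SNR in dB; negated so that better estimates give lower loss
    si_snr = 20.0 * torch.log10(eps + t.norm(dim=-1) / ((x_zm - t).norm(dim=-1) + eps))
    return -si_snr.mean()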
@ -169,7 +140,6 @@ def adopt_weight(weight, global_step, threshold=0, value=0.0):
def adopt_dis_weight(weight, global_step, threshold=0, value=0.0): def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
# 0,3,6,9,13....这些时间步不更新dis
if global_step % 3 == 0: if global_step % 3 == 0:
weight = value weight = value
return weight return weight
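For context, adopt_dis_weight zeroes the discriminator loss weight on every third step, so the discriminator is only updated on the remaining steps. A small self-contained sketch of that behavior (the loop and printed values are illustrative, not code from this repository):

def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
    # as in the file above: zero the weight whenever global_step is a multiple of 3
    if global_step % 3 == 0:
        weight = value
    return weight


for step in range(6):
    print(step, adopt_dis_weight(1.0, step))
# -> 0 0.0, 1 1.0, 2 1.0, 3 0.0, 4 1.0, 5 1.0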

View File

@@ -559,7 +559,7 @@ def train_one_epoch(
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
             )
-            if tb_writer is not None and rank == 0 and speech_hat is not None:
+            if tb_writer is not None and rank == 0:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
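With the changed condition, validation stats are written whenever a TensorBoard writer exists on rank 0, no longer gated on a decoded sample. A rough stand-in for what the summary write is expected to do, assuming valid_info behaves like a metrics tracker that logs one scalar per stat (function name and metric names below are made up):

from torch.utils.tensorboard import SummaryWriter


def write_valid_summary(metrics: dict, tb_writer: SummaryWriter, prefix: str, step: int) -> None:
    # log each tracked scalar under the given prefix at the current training step
    for name, value in metrics.items():
        tb_writer.add_scalar(prefix + name, value, step)

# usage (paths and metric names are illustrative):
# write_valid_summary({"generator_loss": 2.3}, SummaryWriter("exp/tb"), "train/valid_", 1000)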
@@ -634,11 +634,12 @@ def compute_validation_loss(
                 speech_lengths=audio_lens,
                 global_step=params.batch_idx_train,
                 forward_generator=True,
-                return_sample=False,
+                return_sample=True,
             )
             assert loss_g.requires_grad is False
             for k, v in stats_g.items():
-                loss_info[k] = v * batch_size
+                if "returned_sample" not in k:
+                    loss_info[k] = v * batch_size

             # summary stats
             tot_loss = tot_loss + loss_info
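With return_sample=True the generator's stats dict can carry non-scalar entries under keys containing "returned_sample"; the added check keeps those out of the scalar loss accumulation. A standalone sketch of that filtering, with illustrative key names and a hypothetical sample payload:

# illustrative stats dict; the real keys come from the Encodec generator's stats_g
stats_g = {
    "generator_loss": 2.31,
    "generator_reconstruction_loss": 0.87,
    "returned_sample": ("speech_hat", "speech"),  # non-scalar payload, hypothetical contents
}
batch_size = 8

loss_info = {}
for k, v in stats_g.items():
    if "returned_sample" not in k:
        # only scalar stats are scaled by batch size and accumulated into the tracker
        loss_info[k] = v * batch_size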