Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 10:16:14 +00:00)
minor fixes
This commit is contained in:
parent 2e5055a847, commit 0150961a33
@@ -146,7 +146,6 @@ class Encodec(nn.Module):
         # reset cache
         if reuse_cache or not self.training:
             self._cache = None
-
         return loss, stats
 
     def _forward_discriminator(
@@ -30,35 +30,6 @@ def sim_loss(y_disc_r, y_disc_gen):
     return loss / len(y_disc_r)
 
 
-# def sisnr_loss(x, s, eps=1e-8):
-#     """
-#     calculate training loss
-#     input:
-#           x: separated signal, N x S tensor, estimate value
-#           s: reference signal, N x S tensor, True value
-#     Return:
-#           sisnr: N tensor
-#     """
-#     if x.shape != s.shape:
-#         if x.shape[-1] > s.shape[-1]:
-#             x = x[:, :s.shape[-1]]
-#         else:
-#             s = s[:, :x.shape[-1]]
-#     def l2norm(mat, keepdim=False):
-#         return torch.norm(mat, dim=-1, keepdim=keepdim)
-#     if x.shape != s.shape:
-#         raise RuntimeError(
-#             "Dimention mismatch when calculate si-snr, {} vs {}".format(
-#                 x.shape, s.shape))
-#     x_zm = x - torch.mean(x, dim=-1, keepdim=True)
-#     s_zm = s - torch.mean(s, dim=-1, keepdim=True)
-#     t = torch.sum(
-#         x_zm * s_zm, dim=-1,
-#         keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
-#     loss = -20. * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
-#     return torch.sum(loss) / x.shape[0]
-
-
 def reconstruction_loss(x, x_hat, args, eps=1e-7):
     # NOTE (lsx): hard-coded now
     L = args.lambda_wav * F.mse_loss(x, x_hat)  # wav L1 loss
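The block deleted above was a commented-out scale-invariant SNR (SI-SNR) loss. For readers who want the computation it described, here is a cleaned-up, runnable sketch of the same formula: zero-mean both signals, project the estimate onto the reference, then take minus 20 times log10 of the projection-to-residual norm ratio. It follows the deleted comments only; nothing below is a function that still exists in icefall.

import torch


def sisnr_loss(x: torch.Tensor, s: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Negative SI-SNR averaged over the batch.

    x: estimated signal, shape (N, S)
    s: reference signal, shape (N, S)
    """
    # truncate to a common length if the two signals differ
    if x.shape[-1] != s.shape[-1]:
        length = min(x.shape[-1], s.shape[-1])
        x, s = x[..., :length], s[..., :length]

    def l2norm(mat: torch.Tensor, keepdim: bool = False) -> torch.Tensor:
        return torch.norm(mat, dim=-1, keepdim=keepdim)

    # remove the per-utterance mean so the measure ignores DC offset
    x_zm = x - x.mean(dim=-1, keepdim=True)
    s_zm = s - s.mean(dim=-1, keepdim=True)
    # project the estimate onto the reference (scale-invariant target)
    t = (x_zm * s_zm).sum(dim=-1, keepdim=True) * s_zm / (
        l2norm(s_zm, keepdim=True) ** 2 + eps
    )
    loss = -20.0 * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
    return loss.sum() / x.shape[0]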
@@ -169,7 +140,6 @@ def adopt_weight(weight, global_step, threshold=0, value=0.0):
 
-
 def adopt_dis_weight(weight, global_step, threshold=0, value=0.0):
     # At steps 0, 3, 6, 9, 12, ... (every third step), do not update the discriminator.
     if global_step % 3 == 0:
         weight = value
     return weight
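adopt_dis_weight, shown above, zeroes the discriminator's loss weight on every third global step, so the discriminator is effectively updated on only two steps out of three. adopt_weight is only named in the hunk header; the threshold-based body below is an assumption modeled on the common VQGAN-style warm-up helper, and the usage lines are hypothetical:

def adopt_weight(weight: float, global_step: int, threshold: int = 0, value: float = 0.0) -> float:
    # Assumed behaviour (not shown in this diff): disable the weight
    # until training has reached `threshold` steps.
    if global_step < threshold:
        weight = value
    return weight


def adopt_dis_weight(weight: float, global_step: int, threshold: int = 0, value: float = 0.0) -> float:
    # As in the hunk above: skip the discriminator every third step.
    if global_step % 3 == 0:
        weight = value
    return weight


# hypothetical use inside a training step
# d_weight = adopt_dis_weight(1.0, global_step)
# if d_weight > 0:
#     loss_d = d_weight * discriminator_loss(...)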
@@ -559,7 +559,7 @@ def train_one_epoch(
             logging.info(
                 f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
             )
-            if tb_writer is not None and rank == 0 and speech_hat is not None:
+            if tb_writer is not None and rank == 0:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
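The one-line change above drops the `speech_hat is not None` condition, so the periodic validation summary in train_one_epoch is written whenever the process is rank 0 and a TensorBoard writer exists, regardless of whether a generated waveform was returned. A minimal sketch of that guard is below; valid_info and write_summary are names taken from the diff, while SimpleMetrics is a toy stand-in for the project's metrics tracker, not its actual API:

from torch.utils.tensorboard import SummaryWriter


class SimpleMetrics(dict):
    """Toy stand-in for the tracker behind valid_info in the diff above."""

    def write_summary(self, tb_writer: SummaryWriter, prefix: str, step: int) -> None:
        # write each accumulated scalar under a prefixed tag
        for name, value in self.items():
            tb_writer.add_scalar(prefix + name, value, step)


def log_validation(valid_info: SimpleMetrics, tb_writer, rank: int, step: int) -> None:
    # mirrors the changed guard: no dependence on a generated sample
    if tb_writer is not None and rank == 0:
        valid_info.write_summary(tb_writer, "train/valid_", step)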
@@ -634,10 +634,11 @@ def compute_validation_loss(
                 speech_lengths=audio_lens,
                 global_step=params.batch_idx_train,
                 forward_generator=True,
-                return_sample=False,
+                return_sample=True,
             )
             assert loss_g.requires_grad is False
             for k, v in stats_g.items():
-                loss_info[k] = v * batch_size
+                if "returned_sample" not in k:
+                    loss_info[k] = v * batch_size
 
         # summary stats