From 2d0fe7637c15291d096bde7b5e2a6cd91b460d20 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sun, 11 Dec 2022 11:15:56 +0800
Subject: [PATCH] Memory fix in WithLoss

Stop saving the full loss tensor `y` for backward (it was kept alive only
to print its sum and shape); keep just `y.shape` on ctx and do the
occasional debug logging in forward(), where `y` is still live anyway.

NOTE(review): the original patch logged f"...name={ctx.name}..." in
forward() while simultaneously deleting the `ctx.name = name` assignment,
so the 0.2%-probability logging branch raised AttributeError. Fixed below
by interpolating the local parameter `name` (already None-checked by the
guard) instead of the never-set `ctx.name`.

---
 .../ASR/pruned_transducer_stateless7/scaling.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 3b443e1c1..ced3b96d1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -911,17 +911,14 @@ class Whiten(nn.Module):
 class WithLoss(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, y: Tensor, name: str):
-        ctx.name = name
-        ctx.save_for_backward(y) # just for printing the name, and the shape
+        ctx.y_shape = y.shape
+        if random.random() < 0.002 and name is not None:
+            loss_sum = y.sum().item()
+            logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
         return x
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
-        y, = ctx.saved_tensors
-        if random.random() < 0.002 and ctx.name is not None:
-            loss_sum = y.sum().item()
-            logging.info(f"WithLoss: name={ctx.name}, loss-sum={loss_sum:.3e}")
-
-        return ans_grad, torch.ones(y.shape,
+        return ans_grad, torch.ones(ctx.y_shape,
                                     dtype=ans_grad.dtype,
                                     device=ans_grad.device), None
 def with_loss(x, y, name):