Memory fix in WithLoss

This commit is contained in:
Daniel Povey 2022-12-11 11:15:56 +08:00
parent 0edaf4d25c
commit 2d0fe7637c

View File

@ -911,17 +911,14 @@ class Whiten(nn.Module):
class WithLoss(torch.autograd.Function):
    """Autograd function that passes `x` through unchanged while attaching a
    gradient of ones to `y` (so `y` behaves as an auxiliary loss term).

    Memory note: we deliberately do NOT call ctx.save_for_backward(y) —
    saving y would keep the whole tensor alive until backward runs.  Only
    y.shape is needed to build the gradient, so we store just the shape.
    """

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor, name: str):
        # Kept so the (rarely triggered) debug log below can show which
        # loss this is; `name` may be None to disable logging entirely.
        ctx.name = name
        # Store only the shape, not the tensor: this is the memory fix.
        ctx.y_shape = y.shape
        # Occasionally (p ~= 0.002) log the auxiliary loss value for debugging.
        if random.random() < 0.002 and name is not None:
            loss_sum = y.sum().item()
            logging.info(f"WithLoss: name={ctx.name}, loss-sum={loss_sum:.3e}")
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        # x gets the incoming gradient unchanged; y gets a gradient of all
        # ones (its sum acts as a loss); `name` (a str) gets no gradient.
        return ans_grad, torch.ones(ctx.y_shape,
                                    dtype=ans_grad.dtype,
                                    device=ans_grad.device), None
def with_loss(x, y, name):