Bug fix to printing code

This commit is contained in:
Daniel Povey 2022-12-15 21:55:23 +08:00
parent 076b18db60
commit f66c1600f4

View File

@ -914,7 +914,7 @@ class WithLoss(torch.autograd.Function):
ctx.y_shape = y.shape
if random.random() < 0.002 and name is not None:
loss_sum = y.sum().item()
logging.info(f"WithLoss: name={ctx.name}, loss-sum={loss_sum:.3e}")
logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
return x
@staticmethod
def backward(ctx, ans_grad: Tensor):