Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00.
Commit 2d0fe7637c ("Memory fix in WithLoss"), parent commit 0edaf4d25c.
@ -911,17 +911,14 @@ class Whiten(nn.Module):
|
|||||||
class WithLoss(torch.autograd.Function):
    """Attach a loss term ``y`` to a tensor ``x`` in the autograd graph.

    ``forward`` returns ``x`` unchanged; ``backward`` passes ``x``'s
    incoming gradient through untouched and assigns ``y`` a gradient of
    all ones, so ``y`` behaves as if ``y.sum()`` had been added to the
    final loss without affecting the forward value of ``x``.

    Memory note: only ``y.shape`` is stashed on the context rather than
    ``y`` itself (via ``save_for_backward``), so the loss tensor is not
    kept alive until the backward pass — this is the "memory fix" this
    commit introduces.
    """

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor, name: str):
        ctx.name = name
        # Keep only the shape: saving y itself would hold the whole
        # tensor in memory until backward runs.
        ctx.y_shape = y.shape
        # Occasionally (p ~= 0.002) log the loss value for monitoring;
        # done here in forward because y is no longer available in backward.
        if random.random() < 0.002 and name is not None:
            loss_sum = y.sum().item()
            logging.info(f"WithLoss: name={ctx.name}, loss-sum={loss_sum:.3e}")
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        # x's gradient flows through unchanged; y's gradient is 1
        # everywhere (y is treated as summed into the loss); the `name`
        # argument gets no gradient.
        return ans_grad, torch.ones(ctx.y_shape,
                                    dtype=ans_grad.dtype,
                                    device=ans_grad.device), None
def with_loss(x, y, name):
|
def with_loss(x, y, name):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user