Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Bug fixes to LinearWithAuxLoss

commit 1ebc3dd158
parent 0a997d64c4
@@ -404,8 +404,15 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor, None]:
         x, weight, alpha = ctx.saved_tensors
+
+        x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
+        weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
+                                   x.reshape(-1, x.shape[-1]).to(ans_grad.dtype))
+
+
         with torch.cuda.amp.autocast(enabled=False):
             with torch.enable_grad():
+                x = x.to(weight.dtype)
                 x, weight, alpha = x.detach(), weight.detach(), alpha.detach()
                 weight.requires_grad = True
                 alpha.requires_grad = True
@@ -423,9 +430,6 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
                 weight_aux_grad = weight.grad
                 alpha_grad = alpha.grad
 
-        x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
-        weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
-                                   x.reshape(-1, x.shape[-1]).to(ans_grad.dtype))
 
         with torch.cuda.amp.autocast(enabled=False):
             weight_grad_norm = weight_grad.to(torch.float32).norm()
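
For readers skimming the diff, here is a self-contained sketch of the pattern the fix enforces: the main-path gradients (x_grad, weight_grad) are computed from the tensors exactly as they were saved in forward (possibly fp16 under AMP), and only afterwards are x, weight and alpha cast and detached for the auxiliary-loss branch, which runs with autocast disabled. Apart from those lines, everything below is a placeholder for illustration: the class name, the forward, the aux-loss formula and the way weight.grad is folded into weight_grad are assumptions, not the icefall implementation.

# Minimal sketch, not the icefall code: illustrates the gradient-ordering
# pattern that the commit enforces.  Names and the aux-loss formula are
# placeholders chosen for illustration.
from typing import Tuple

import torch
from torch import Tensor


class ToyLinearWithAuxLossFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, weight: Tensor, alpha: Tensor) -> Tensor:
        ctx.save_for_backward(x, weight, alpha)
        # weight is (out_features, in_features); cast to x's dtype so the
        # forward also works when x is half precision under autocast.
        return torch.matmul(x, weight.t().to(x.dtype))

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        x, weight, alpha = ctx.saved_tensors

        # Main-path gradients, computed from the tensors exactly as saved,
        # before x is rebound/cast for the aux-loss branch below.
        x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
        weight_grad = torch.matmul(
            ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
            x.reshape(-1, x.shape[-1]).to(ans_grad.dtype),
        )

        # Auxiliary-loss branch: full precision, with a local autograd graph
        # built over detached copies of the saved tensors.
        with torch.cuda.amp.autocast(enabled=False):
            with torch.enable_grad():
                x = x.to(weight.dtype)
                x, weight, alpha = x.detach(), weight.detach(), alpha.detach()
                weight.requires_grad = True
                alpha.requires_grad = True
                # Placeholder aux loss; the actual icefall formula is not shown
                # in this diff.
                aux_loss = alpha * torch.matmul(x, weight.t()).pow(2).mean()
                aux_loss.backward()
                weight_grad = weight_grad + weight.grad.to(weight_grad.dtype)
                alpha_grad = alpha.grad

        return x_grad, weight_grad, alpha_grad


if __name__ == "__main__":
    x = torch.randn(4, 3, requires_grad=True)
    weight = torch.randn(5, 3, requires_grad=True)
    alpha = torch.tensor(0.01, requires_grad=True)
    y = ToyLinearWithAuxLossFunction.apply(x, weight, alpha)
    y.sum().backward()
    print(x.grad.shape, weight.grad.shape, alpha.grad)

Read this way, the bug the commit addresses looks like an ordering problem: the aux-loss branch rebinds x, weight and alpha to detached copies, so computing x_grad and weight_grad after that branch (as the lines removed in the second hunk did) would use the rebound tensors rather than the ones saved from the forward pass, and the added x = x.to(weight.dtype) makes the aux-loss computation run in the weight's dtype even when x was saved in half precision.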