mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Fixes for half precision
This commit is contained in:
parent
6a91f343e9
commit
0a997d64c4
@@ -423,9 +423,9 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
|
||||
weight_aux_grad = weight.grad
|
||||
alpha_grad = alpha.grad
|
||||
|
||||
- x_grad = torch.matmul(ans_grad, weight)
|
||||
+ x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
|
||||
weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
|
||||
- x.reshape(-1, x.shape[-1]))
|
||||
+ x.reshape(-1, x.shape[-1]).to(ans_grad.dtype))
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
weight_grad_norm = weight_grad.to(torch.float32).norm()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user