Fixes for half precision

Daniel Povey 2022-11-25 16:07:47 +08:00
parent 6a91f343e9
commit 0a997d64c4


@@ -423,9 +423,9 @@ class LinearWithAuxLossFunction(torch.autograd.Function):
         weight_aux_grad = weight.grad
         alpha_grad = alpha.grad
-        x_grad = torch.matmul(ans_grad, weight)
+        x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
         weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
-                                   x.reshape(-1, x.shape[-1]))
+                                   x.reshape(-1, x.shape[-1]).to(ans_grad.dtype))
         with torch.cuda.amp.autocast(enabled=False):
             weight_grad_norm = weight_grad.to(torch.float32).norm()
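Why the casts are needed: the backward pass runs with autocast disabled, so ans_grad and the saved x can arrive as float16 while the weight parameter is still float32, and torch.matmul raises a dtype mismatch on mixed inputs. Casting weight and x to ans_grad.dtype keeps both matmuls in a single dtype; the gradient norm is then taken in float32 (inside autocast(enabled=False)) so it cannot overflow in half precision. Below is a minimal runnable sketch of the same pattern; SimpleLinearFunction and the tensor shapes are illustrative, not the repository's actual LinearWithAuxLossFunction:

import torch

class SimpleLinearFunction(torch.autograd.Function):
    """Linear layer (y = x @ weight.t()) that tolerates a float16 input
    against a float32 weight, mirroring the casts in this commit."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        # Cast the weight so the matmul sees one dtype even when
        # x is float16 and weight is a float32 Parameter.
        return torch.matmul(x, weight.t().to(x.dtype))

    @staticmethod
    def backward(ctx, ans_grad):
        x, weight = ctx.saved_tensors
        # ans_grad and x may be float16 here; cast as the commit does.
        x_grad = torch.matmul(ans_grad, weight.to(ans_grad.dtype))
        weight_grad = torch.matmul(ans_grad.reshape(-1, ans_grad.shape[-1]).t(),
                                   x.reshape(-1, x.shape[-1]).to(ans_grad.dtype))
        # Return the weight gradient in the parameter's own dtype.
        return x_grad, weight_grad.to(weight.dtype)

if torch.cuda.is_available():
    x = torch.randn(4, 8, device="cuda", dtype=torch.float16, requires_grad=True)
    w = torch.randn(16, 8, device="cuda", requires_grad=True)  # float32 "Parameter"
    y = SimpleLinearFunction.apply(x, w)
    y.sum().backward()
    print(x.grad.dtype, w.grad.dtype)  # torch.float16 torch.float32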