mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp151' into scaled_adam_exp150
This commit is contained in:
commit
f08a869769
@@ -255,10 +255,13 @@ class SoftmaxFunction(torch.autograd.Function):
|
||||
"""
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor, dim: int):
|
||||
ans = x.softmax(dim=dim)
|
||||
ctx.save_for_backward(ans)
|
||||
ctx.dim = dim
|
||||
return ans
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
if x.dtype == torch.float16:
|
||||
x = x.to(torch.float32)
|
||||
ans = x.softmax(dim=dim)
|
||||
ctx.save_for_backward(ans)
|
||||
ctx.dim = dim
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, ans_grad: Tensor):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user