Merge branch 'scaled_adam_exp151' into scaled_adam_exp150
commit f08a869769
@@ -255,10 +255,13 @@ class SoftmaxFunction(torch.autograd.Function):
     """
     @staticmethod
     def forward(ctx, x: Tensor, dim: int):
-        ans = x.softmax(dim=dim)
-        ctx.save_for_backward(ans)
-        ctx.dim = dim
-        return ans
+        with torch.cuda.amp.autocast(enabled=False):
+            if x.dtype == torch.float16:
+                x = x.to(torch.float32)
+            ans = x.softmax(dim=dim)
+            ctx.save_for_backward(ans)
+            ctx.dim = dim
+            return ans
 
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
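For context: this change wraps SoftmaxFunction.forward in torch.cuda.amp.autocast(enabled=False) and upcasts float16 inputs to float32, so that under mixed-precision training the softmax, and the activations saved for backward, are computed in full precision. Below is a minimal, self-contained sketch of what the function might look like after this commit. The forward body mirrors the hunk above; the backward body is not part of this hunk, so the standard softmax gradient used there, and the small self-check at the bottom, are illustrative assumptions only, not the repository's actual code.

# Sketch only: forward follows the diff above; backward is an assumed standard
# softmax gradient, since the real backward body lies outside this hunk.
import torch
from torch import Tensor


class SoftmaxFunction(torch.autograd.Function):
    """Softmax computed in float32 even when autocast would run it in float16."""

    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        with torch.cuda.amp.autocast(enabled=False):
            # Under mixed precision the input may arrive as float16; upcast so
            # the softmax and the tensor saved for backward are full precision.
            if x.dtype == torch.float16:
                x = x.to(torch.float32)
            ans = x.softmax(dim=dim)
            ctx.save_for_backward(ans)
            ctx.dim = dim
            return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        # Assumed standard softmax backward: grad_x = y * (grad_y - sum(grad_y * y)).
        (ans,) = ctx.saved_tensors
        with torch.cuda.amp.autocast(enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None


if __name__ == "__main__":
    # Compare against torch.softmax's autograd gradient on a non-trivial loss.
    x = torch.randn(2, 5, requires_grad=True)
    x_ref = x.detach().clone().requires_grad_(True)
    w = torch.randn(2, 5)
    (SoftmaxFunction.apply(x, -1) * w).sum().backward()
    (torch.softmax(x_ref, dim=-1) * w).sum().backward()
    assert torch.allclose(x.grad, x_ref.grad, atol=1e-6)

Computing the softmax in float32 avoids the overflow and underflow that float16 exponentials can hit, at the cost of keeping the saved activation in float32.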