Mirror of https://github.com/k2-fsa/icefall.git
putting back custom_bwd, custom_fwd
commit c3acfcfa6c
parent 77357ebb06
@@ -24,7 +24,7 @@ import k2
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.amp import custom_bwd, custom_fwd
+from torch.cuda.amp import custom_bwd, custom_fwd
 
 from icefall.utils import torch_autocast
 
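Note (not part of this commit): the only functional change in the hunk above is which module provides custom_bwd and custom_fwd; the bare decorators used further down keep working either way. A minimal compatibility sketch, assuming one wanted to prefer the torch.amp API where it exists and fall back to torch.cuda.amp otherwise:

    from functools import partial

    try:
        # Newer PyTorch exposes device-agnostic decorators under torch.amp.
        from torch.amp import custom_bwd, custom_fwd

        # Bind device_type once so the bare @custom_fwd / @custom_bwd
        # spelling used in the hunks below still works unchanged.
        custom_fwd = partial(custom_fwd, device_type="cuda")
        custom_bwd = partial(custom_bwd, device_type="cuda")
    except ImportError:
        # Older PyTorch only ships the CUDA-specific variants.
        from torch.cuda.amp import custom_bwd, custom_fwd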
@@ -1306,7 +1306,7 @@ class MulForDropout3(torch.autograd.Function):
     # returns (x * y * alpha) where alpha is a float and y doesn't require
     # grad and is zero-or-one.
     @staticmethod
-    @custom_fwd(device_type='cuda')
+    @custom_fwd
     def forward(ctx, x, y, alpha):
         assert not y.requires_grad
         ans = x * y * alpha
@@ -1315,7 +1315,7 @@ class MulForDropout3(torch.autograd.Function):
         return ans
 
     @staticmethod
-    @custom_bwd(device_type='cuda')
+    @custom_bwd
     def backward(ctx, ans_grad):
         (ans,) = ctx.saved_tensors
         x_grad = ctx.alpha * ans_grad * (ans != 0)
@@ -1512,7 +1512,7 @@ def SwooshRForward(x: Tensor):
 
 class ActivationDropoutAndLinearFunction(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type='cuda')
+    @custom_fwd
     def forward(
         ctx,
         x: Tensor,
@@ -1551,7 +1551,7 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function):
         return x
 
     @staticmethod
-    @custom_bwd(device_type='cuda')
+    @custom_bwd
     def backward(ctx, ans_grad: Tensor):
         saved = ctx.saved_tensors
         (x, weight, bias, dropout_mask) = saved
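For context, a hypothetical sketch (not code from this repository) of the decorator pattern these hunks touch: custom_fwd/custom_bwd are applied to torch.autograd.Function subclasses so that backward() executes with the same autocast state that forward() was called under.

    import torch
    from torch.cuda.amp import custom_bwd, custom_fwd


    class MulByAlpha(torch.autograd.Function):
        # Toy Function in the spirit of MulForDropout3: forward scales the
        # input by a constant; backward scales the incoming gradient.
        @staticmethod
        @custom_fwd
        def forward(ctx, x, alpha):
            ctx.alpha = alpha
            return x * alpha

        @staticmethod
        @custom_bwd
        def backward(ctx, ans_grad):
            # Runs with the same autocast state that forward() saw.
            return ans_grad * ctx.alpha, None


    if torch.cuda.is_available():
        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        with torch.cuda.amp.autocast():
            y = MulByAlpha.apply(x, 0.5)
        y.sum().backward()
        print(x.grad.shape)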