putting back custom_bwd, custom_fwd

Karel Vesely 2025-09-09 10:21:35 +02:00
parent 77357ebb06
commit c3acfcfa6c

@@ -24,7 +24,7 @@ import k2
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.amp import custom_bwd, custom_fwd
+from torch.cuda.amp import custom_bwd, custom_fwd
 from icefall.utils import torch_autocast
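This hunk puts the torch.cuda.amp spelling of the decorators back in place of the torch.amp one. If the goal were to support both API generations instead of pinning one, a small compatibility shim is a common pattern. The sketch below is not part of this commit; it only assumes what the two sides of the hunk already show, namely that newer PyTorch (roughly 2.4+) exposes torch.amp.custom_fwd/custom_bwd with a device_type keyword, while older releases ship only the torch.cuda.amp variants.

import functools

try:
    # Newer PyTorch (assumed >= 2.4): decorators live in torch.amp
    # and require an explicit device_type.
    from torch.amp import custom_bwd as _custom_bwd
    from torch.amp import custom_fwd as _custom_fwd

    custom_fwd = functools.partial(_custom_fwd, device_type="cuda")
    custom_bwd = functools.partial(_custom_bwd, device_type="cuda")
except ImportError:
    # Older PyTorch: CUDA-only decorators, used bare as @custom_fwd / @custom_bwd.
    from torch.cuda.amp import custom_bwd, custom_fwd  # noqa: F401

Either way, call sites keep the bare @custom_fwd / @custom_bwd spelling that the rest of this diff restores.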
@@ -1306,7 +1306,7 @@ class MulForDropout3(torch.autograd.Function):
     # returns (x * y * alpha) where alpha is a float and y doesn't require
     # grad and is zero-or-one.
     @staticmethod
-    @custom_fwd(device_type='cuda')
+    @custom_fwd
     def forward(ctx, x, y, alpha):
         assert not y.requires_grad
         ans = x * y * alpha
@@ -1315,7 +1315,7 @@ class MulForDropout3(torch.autograd.Function):
         return ans

     @staticmethod
-    @custom_bwd(device_type='cuda')
+    @custom_bwd
     def backward(ctx, ans_grad):
         (ans,) = ctx.saved_tensors
         x_grad = ctx.alpha * ans_grad * (ans != 0)
@@ -1512,7 +1512,7 @@ def SwooshRForward(x: Tensor):

 class ActivationDropoutAndLinearFunction(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type='cuda')
+    @custom_fwd
     def forward(
         ctx,
         x: Tensor,
@@ -1551,7 +1551,7 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function):
         return x

     @staticmethod
-    @custom_bwd(device_type='cuda')
+    @custom_bwd
     def backward(ctx, ans_grad: Tensor):
         saved = ctx.saved_tensors
         (x, weight, bias, dropout_mask) = saved
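To make the restored decorator semantics concrete, here is a minimal, self-contained sketch (not from the repo) in the same style as the functions touched above: with the torch.cuda.amp decorators used bare, forward runs under whatever autocast state is active at the call site, and custom_bwd re-runs backward under that same autocast state. The ScaledMaskedMul name, shapes, and dropout rate are illustrative assumptions.

import torch
from torch.cuda.amp import custom_bwd, custom_fwd


class ScaledMaskedMul(torch.autograd.Function):
    """Illustrative stand-in for MulForDropout3: returns x * mask * alpha."""

    @staticmethod
    @custom_fwd
    def forward(ctx, x, mask, alpha):
        # mask is expected to be zero-or-one and must not require grad.
        assert not mask.requires_grad
        ans = x * mask * alpha
        ctx.save_for_backward(ans)
        ctx.alpha = alpha
        return ans

    @staticmethod
    @custom_bwd
    def backward(ctx, ans_grad):
        # Recover the mask from the saved output: nonzero output <=> mask == 1.
        (ans,) = ctx.saved_tensors
        x_grad = ctx.alpha * ans_grad * (ans != 0)
        return x_grad, None, None


if torch.cuda.is_available():
    x = torch.randn(4, 8, device="cuda", requires_grad=True)
    mask = (torch.rand(4, 8, device="cuda") > 0.1).to(x.dtype)  # zero-or-one mask
    with torch.autocast("cuda", dtype=torch.float16):
        out = ScaledMaskedMul.apply(x, mask, 1.0 / 0.9)  # rescale to keep expectation
    out.sum().backward()  # custom_bwd keeps backward in the same autocast state

On PyTorch 2.4+ the torch.cuda.amp imports still work but emit a deprecation warning, which is the trade-off this commit accepts in exchange for compatibility with older releases.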