putting back custom_bwd, custom_fwd

Karel Vesely 2025-09-09 10:21:35 +02:00
parent 77357ebb06
commit c3acfcfa6c


@@ -24,7 +24,7 @@ import k2
import torch
import torch.nn as nn
from torch import Tensor
-from torch.amp import custom_bwd, custom_fwd
+from torch.cuda.amp import custom_bwd, custom_fwd
from icefall.utils import torch_autocast
@@ -1306,7 +1306,7 @@ class MulForDropout3(torch.autograd.Function):
# returns (x * y * alpha) where alpha is a float and y doesn't require
# grad and is zero-or-one.
@staticmethod
-@custom_fwd(device_type='cuda')
+@custom_fwd
def forward(ctx, x, y, alpha):
assert not y.requires_grad
ans = x * y * alpha
@@ -1315,7 +1315,7 @@ class MulForDropout3(torch.autograd.Function):
return ans
@staticmethod
-@custom_bwd(device_type='cuda')
+@custom_bwd
def backward(ctx, ans_grad):
(ans,) = ctx.saved_tensors
x_grad = ctx.alpha * ans_grad * (ans != 0)
@@ -1512,7 +1512,7 @@ def SwooshRForward(x: Tensor):
class ActivationDropoutAndLinearFunction(torch.autograd.Function):
@staticmethod
-@custom_fwd(device_type='cuda')
+@custom_fwd
def forward(
ctx,
x: Tensor,
@@ -1551,7 +1551,7 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function):
return x
@staticmethod
-@custom_bwd(device_type='cuda')
+@custom_bwd
def backward(ctx, ans_grad: Tensor):
saved = ctx.saved_tensors
(x, weight, bias, dropout_mask) = saved
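
Note (not part of this commit): the change pins the file to the older torch.cuda.amp decorators, dropping the torch.amp variants that take a device_type argument. If one wanted to keep compatibility with both APIs instead, an import shim along the following lines could work. This is only a sketch based on the two import variants visible in the diff above, not code from the repository.

# Hypothetical compatibility shim (sketch, not from this commit): prefer the
# torch.amp decorators when they exist, otherwise fall back to torch.cuda.amp.
import functools

try:
    # Newer PyTorch: custom_fwd/custom_bwd live in torch.amp and take a
    # keyword-only device_type argument.
    from torch.amp import custom_bwd as _custom_bwd
    from torch.amp import custom_fwd as _custom_fwd

    # Bind device_type so call sites can keep using a bare @custom_fwd /
    # @custom_bwd, as in the version this commit puts back.
    custom_fwd = functools.partial(_custom_fwd, device_type="cuda")
    custom_bwd = functools.partial(_custom_bwd, device_type="cuda")
except ImportError:
    # Older PyTorch: the CUDA-only decorators take no device_type argument.
    from torch.cuda.amp import custom_bwd, custom_fwd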