mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 05:55:26 +00:00
putting back custom_bwd, custom_fwd
This commit is contained in:
parent
77357ebb06
commit
c3acfcfa6c
@ -24,7 +24,7 @@ import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.amp import custom_bwd, custom_fwd
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from icefall.utils import torch_autocast
|
||||
|
||||
@ -1306,7 +1306,7 @@ class MulForDropout3(torch.autograd.Function):
|
||||
# returns (x * y * alpha) where alpha is a float and y doesn't require
|
||||
# grad and is zero-or-one.
|
||||
@staticmethod
|
||||
@custom_fwd(device_type='cuda')
|
||||
@custom_fwd
|
||||
def forward(ctx, x, y, alpha):
|
||||
assert not y.requires_grad
|
||||
ans = x * y * alpha
|
||||
@ -1315,7 +1315,7 @@ class MulForDropout3(torch.autograd.Function):
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd(device_type='cuda')
|
||||
@custom_bwd
|
||||
def backward(ctx, ans_grad):
|
||||
(ans,) = ctx.saved_tensors
|
||||
x_grad = ctx.alpha * ans_grad * (ans != 0)
|
||||
@ -1512,7 +1512,7 @@ def SwooshRForward(x: Tensor):
|
||||
|
||||
class ActivationDropoutAndLinearFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(device_type='cuda')
|
||||
@custom_fwd
|
||||
def forward(
|
||||
ctx,
|
||||
x: Tensor,
|
||||
@ -1551,7 +1551,7 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function):
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd(device_type='cuda')
|
||||
@custom_bwd
|
||||
def backward(ctx, ans_grad: Tensor):
|
||||
saved = ctx.saved_tensors
|
||||
(x, weight, bias, dropout_mask) = saved
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user