From c3acfcfa6c29c9e250072a236c3481cdfe1da940 Mon Sep 17 00:00:00 2001
From: Karel Vesely
Date: Tue, 9 Sep 2025 10:21:35 +0200
Subject: [PATCH] putting back `custom_bwd`, `custom_fwd`

---
 egs/librispeech/ASR/zipformer/scaling.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py
index 5994f01bf..22aa1b1ca 100644
--- a/egs/librispeech/ASR/zipformer/scaling.py
+++ b/egs/librispeech/ASR/zipformer/scaling.py
@@ -24,7 +24,7 @@ import k2
 import torch
 import torch.nn as nn
 from torch import Tensor
-from torch.amp import custom_bwd, custom_fwd
+from torch.cuda.amp import custom_bwd, custom_fwd
 
 from icefall.utils import torch_autocast
 
@@ -1306,7 +1306,7 @@ class MulForDropout3(torch.autograd.Function):
     # returns (x * y * alpha) where alpha is a float and y doesn't require
     # grad and is zero-or-one.
     @staticmethod
-    @custom_fwd(device_type='cuda')
+    @custom_fwd
     def forward(ctx, x, y, alpha):
         assert not y.requires_grad
         ans = x * y * alpha
@@ -1315,7 +1315,7 @@ class MulForDropout3(torch.autograd.Function):
         return ans
 
     @staticmethod
-    @custom_bwd(device_type='cuda')
+    @custom_bwd
     def backward(ctx, ans_grad):
         (ans,) = ctx.saved_tensors
         x_grad = ctx.alpha * ans_grad * (ans != 0)
@@ -1512,7 +1512,7 @@ def SwooshRForward(x: Tensor):
 
 class ActivationDropoutAndLinearFunction(torch.autograd.Function):
     @staticmethod
-    @custom_fwd(device_type='cuda')
+    @custom_fwd
     def forward(
         ctx,
         x: Tensor,
@@ -1551,7 +1551,7 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function):
         return x
 
     @staticmethod
-    @custom_bwd(device_type='cuda')
+    @custom_bwd
     def backward(ctx, ans_grad: Tensor):
         saved = ctx.saved_tensors
         (x, weight, bias, dropout_mask) = saved
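
For reference, a minimal self-contained sketch (not part of the patch) of how the legacy decorators restored above are applied to a torch.autograd.Function, mirroring the pattern in scaling.py. The ScaleByAlpha class here is a hypothetical toy example, and it assumes a PyTorch version where torch.cuda.amp.custom_fwd/custom_bwd are still available; unlike the newer torch.amp variants, they take no device_type argument.

import torch
from torch.cuda.amp import custom_bwd, custom_fwd


class ScaleByAlpha(torch.autograd.Function):
    # Toy function returning x * alpha; it only illustrates decorator placement.
    @staticmethod
    @custom_fwd  # legacy form: no device_type argument
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x * alpha

    @staticmethod
    @custom_bwd  # legacy form: no device_type argument
    def backward(ctx, ans_grad):
        # gradient w.r.t. x is alpha * ans_grad; alpha is a float, so no grad for it
        return ans_grad * ctx.alpha, None


x = torch.randn(4, requires_grad=True)
y = ScaleByAlpha.apply(x, 0.5)
y.sum().backward()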