# Gradient-scaling helpers, as added to
# egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py.


class ScaleGradFunction(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient in the backward pass.

    The forward pass returns ``x`` unchanged.  The backward pass multiplies
    the incoming gradient by ``alpha``.
    """

    @staticmethod
    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
        """Return ``x`` unchanged and remember ``alpha`` for the backward pass.

        Args:
          ctx: autograd context object.
          x: the input tensor.
          alpha: factor that the gradient w.r.t. ``x`` is multiplied by.
        """
        # alpha is stored directly on ctx; no tensor is saved for backward.
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
        """Return ``(grad * alpha, None)``.

        One value is returned per argument of ``forward``; ``alpha`` receives
        no gradient, hence the ``None``.
        """
        return grad * ctx.alpha, None


def scale_grad(x: Tensor, alpha: float) -> Tensor:
    """Return ``x`` unchanged, with its gradient multiplied by ``alpha``.

    Args:
      x: the input tensor.
      alpha: factor applied to the gradient flowing back through ``x``.
    Returns:
      The output of ``ScaleGradFunction.apply(x, alpha)``, whose value
      equals ``x``.
    """
    return ScaleGradFunction.apply(x, alpha)


class ScaleGrad(nn.Module):
    """Module wrapper around :func:`scale_grad`.

    In the accompanying zipformer.py change this class is imported from
    ``scaling`` and used as ``ScaleGrad(0.1)`` inside the ``nn.Sequential``
    of ``Conv2dSubsampling``, directly after the first ``nn.Conv2d``.

    Args:
      alpha: factor that gradients passing through this module are
        multiplied by.
    """

    def __init__(self, alpha: float):
        super().__init__()
        self.alpha = alpha

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x``; in backprop its gradient is scaled by ``self.alpha``."""
        # Custom autograd Functions cannot be compiled by TorchScript.  The
        # forward computation is an identity, so when scripting or tracing
        # (i.e. exporting) we simply bypass the Function.
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return x
        return scale_grad(x, self.alpha)

    def extra_repr(self) -> str:
        """Show ``alpha`` when the module is printed."""
        return f"alpha={self.alpha}"
self.conv = nn.Sequential( - ScalarMultiply(0.1), - ScaledConv2d( + nn.Conv2d( in_channels=1, out_channels=layer1_channels, kernel_size=3, padding=(0, 1), # (time, freq) - initial_scale=5.0, ), - ScalarMultiply(0.25), + ScaleGrad(0.1), ActivationBalancer(layer1_channels, channel_dim=1), DoubleSwish(), - ScaledConv2d( + nn.Conv2d( in_channels=layer1_channels, out_channels=layer2_channels, kernel_size=3, stride=2, padding=0, - initial_scale=5.0, ), ActivationBalancer(layer2_channels, channel_dim=1),