Revert scaling, scale only grad.

Daniel Povey 2022-12-05 17:53:23 +08:00
parent b93cf0676a
commit 178eca1c0e
2 changed files with 24 additions and 6 deletions


@@ -900,6 +900,26 @@ def with_loss(x, y):
     return WithLoss.apply(x, y)
 
+class ScaleGradFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
+        ctx.alpha = alpha
+        return x
+
+    @staticmethod
+    def backward(ctx, grad: Tensor):
+        return grad * ctx.alpha, None
+
+def scale_grad(x: Tensor, alpha: float):
+    return ScaleGradFunction.apply(x, alpha)
+
+class ScaleGrad(nn.Module):
+    def __init__(self, alpha: float):
+        super().__init__()
+        self.alpha = alpha
+
+    def forward(self, x: Tensor) -> Tensor:
+        return scale_grad(x, self.alpha)
+
 class LimitParamValue(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, min: float, max: float):
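For context (not part of the commit): a minimal sketch of what the new scale_grad / ScaleGrad utilities do, assuming PyTorch and the definitions added above. The forward value passes through unchanged; only the gradient flowing backward is multiplied by alpha.

    import torch

    # scale_grad as defined in the hunk above.
    x = torch.ones(3, requires_grad=True)
    y = scale_grad(x, 0.1)      # forward: y has exactly the values of x
    y.sum().backward()
    print(x.grad)               # tensor([0.1000, 0.1000, 0.1000]), i.e. grad * alpha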


@@ -43,6 +43,7 @@ from scaling import (
     ScheduledFloat,
     FloatLike,
     limit_param_value,
+    ScaleGrad,
 )
 from torch import Tensor, nn
@@ -1719,25 +1720,22 @@ class Conv2dSubsampling(nn.Module):
         # a too-large gradient).
         self.conv = nn.Sequential(
             ScalarMultiply(0.1),
-            ScaledConv2d(
+            nn.Conv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
                 kernel_size=3,
                 padding=(0, 1),  # (time, freq)
-                initial_scale=5.0,
             ),
-            ScalarMultiply(0.25),
+            ScaleGrad(0.1),
             ActivationBalancer(layer1_channels,
                                channel_dim=1),
             DoubleSwish(),
-            ScaledConv2d(
+            nn.Conv2d(
                 in_channels=layer1_channels,
                 out_channels=layer2_channels,
                 kernel_size=3,
                 stride=2,
                 padding=0,
-                initial_scale=5.0,
             ),
             ActivationBalancer(layer2_channels,
                                channel_dim=1),
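Why this matters (a sketch, not part of the commit): ScaleGrad(0.1) placed after a plain nn.Conv2d leaves the forward activations untouched and only shrinks the gradient reaching the conv and everything upstream, whereas the reverted initial_scale=5.0 / ScalarMultiply(0.25) approach presumably rescaled the forward outputs as well, hence the commit title: scale only the grad. Assuming PyTorch and the ScaleGrad module from the first file:

    import torch
    import torch.nn as nn

    # Illustrative shapes, not the ones used in Conv2dSubsampling.
    conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
    block = nn.Sequential(conv, ScaleGrad(0.1))

    x = torch.randn(2, 1, 16, 16)
    assert torch.allclose(block(x), conv(x))   # forward pass is unchanged

    block(x).sum().backward()
    g_scaled = conv.weight.grad.clone()
    conv.weight.grad = None
    conv(x).sum().backward()
    assert torch.allclose(g_scaled, 0.1 * conv.weight.grad)  # gradient scaled by 0.1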