Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Revert scaling, scale only grad.

Replaces the ScaledConv2d / ScalarMultiply activation scaling in Conv2dSubsampling with a ScaleGrad module that leaves the forward pass unchanged and scales only the backward gradient.

parent b93cf0676a
commit 178eca1c0e
@@ -900,6 +900,26 @@ def with_loss(x, y):
     return WithLoss.apply(x, y)
 
 
+class ScaleGradFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
+        ctx.alpha = alpha
+        return x
+
+    @staticmethod
+    def backward(ctx, grad: Tensor):
+        return grad * ctx.alpha, None
+
+
+def scale_grad(x: Tensor, alpha: float):
+    return ScaleGradFunction.apply(x, alpha)
+
+
+class ScaleGrad(nn.Module):
+    def __init__(self, alpha: float):
+        super().__init__()
+        self.alpha = alpha
+
+    def forward(self, x: Tensor) -> Tensor:
+        return scale_grad(x, self.alpha)
+
+
 class LimitParamValue(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor, min: float, max: float):
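The added Function is the identity in the forward pass; only the backward gradient is multiplied by alpha. A quick standalone check of that behavior (illustrative only, not part of the commit; the Function is re-declared here so the snippet runs on its own):

import torch
from torch import Tensor


class ScaleGradFunction(torch.autograd.Function):
    # Identity in forward; scales the incoming gradient by alpha in backward.
    @staticmethod
    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
        # Second return value is the gradient w.r.t. alpha (none, since
        # alpha is a plain float argument).
        return grad * ctx.alpha, None


x = torch.randn(4, requires_grad=True)
y = ScaleGradFunction.apply(x, 0.1)
assert torch.equal(y, x)  # forward output is exactly x

y.sum().backward()
# d(sum)/dx would be all ones; the backward hook shrinks it to 0.1
assert torch.allclose(x.grad, torch.full_like(x, 0.1))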
@@ -43,6 +43,7 @@ from scaling import (
     ScheduledFloat,
     FloatLike,
     limit_param_value,
+    ScaleGrad,
 )
 from torch import Tensor, nn
 
@@ -1719,25 +1720,22 @@ class Conv2dSubsampling(nn.Module):
         # a too-large gradient).
 
         self.conv = nn.Sequential(
-            ScalarMultiply(0.1),
-            ScaledConv2d(
+            nn.Conv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
                 kernel_size=3,
                 padding=(0, 1),  # (time, freq)
-                initial_scale=5.0,
             ),
-            ScalarMultiply(0.25),
+            ScaleGrad(0.1),
             ActivationBalancer(layer1_channels,
                                channel_dim=1),
             DoubleSwish(),
-            ScaledConv2d(
+            nn.Conv2d(
                 in_channels=layer1_channels,
                 out_channels=layer2_channels,
                 kernel_size=3,
                 stride=2,
                 padding=0,
-                initial_scale=5.0,
             ),
             ActivationBalancer(layer2_channels,
                                channel_dim=1),
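Net effect in Conv2dSubsampling: activations are no longer multiplied by 0.1 / 0.25 in the forward pass; only the gradient flowing back into the first conv is damped, by a factor of 10. A standalone sketch of that difference (the layer sizes and the local ScaleGrad re-declaration are illustrative, not from the commit):

import torch
from torch import Tensor, nn


class _ScaleGradFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
        return grad * ctx.alpha, None


class ScaleGrad(nn.Module):
    # Local stand-in for the ScaleGrad module added in scaling.py above.
    def __init__(self, alpha: float):
        super().__init__()
        self.alpha = alpha

    def forward(self, x: Tensor) -> Tensor:
        return _ScaleGradFn.apply(x, self.alpha)


conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=(0, 1))
conv2 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
with_scale = nn.Sequential(conv1, ScaleGrad(0.1), conv2)
without_scale = nn.Sequential(conv1, conv2)

x = torch.randn(1, 1, 16, 16)
# Forward outputs are identical: ScaleGrad is the identity.
assert torch.equal(with_scale(x), without_scale(x))

with_scale(x).sum().backward()
g_damped = conv1.weight.grad.clone()
conv1.weight.grad = None
conv2.weight.grad = None
without_scale(x).sum().backward()
# The gradient reaching conv1 is 10x smaller when ScaleGrad(0.1) follows it.
assert torch.allclose(conv1.weight.grad * 0.1, g_damped, atol=1e-6)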