mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Revert scaling, scale only grad.
This commit is contained in:
parent
b93cf0676a
commit
178eca1c0e
@ -900,6 +900,26 @@ def with_loss(x, y):
|
|||||||
return WithLoss.apply(x, y)
|
return WithLoss.apply(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class ScaleGradFunction(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x: Tensor, alpha: float) -> Tensor:
|
||||||
|
ctx.alpha = alpha
|
||||||
|
return x
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad: Tensor):
|
||||||
|
return grad * ctx.alpha, None
|
||||||
|
|
||||||
|
def scale_grad(x: Tensor, alpha: float):
|
||||||
|
return ScaleGradFunction.apply(x, alpha)
|
||||||
|
|
||||||
|
class ScaleGrad(nn.Module):
|
||||||
|
def __init__(self, alpha: float):
|
||||||
|
super().__init__()
|
||||||
|
self.alpha = alpha
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return scale_grad(x, self.alpha)
|
||||||
|
|
||||||
class LimitParamValue(torch.autograd.Function):
|
class LimitParamValue(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x: Tensor, min: float, max: float):
|
def forward(ctx, x: Tensor, min: float, max: float):
|
||||||
|
|||||||
@ -43,6 +43,7 @@ from scaling import (
|
|||||||
ScheduledFloat,
|
ScheduledFloat,
|
||||||
FloatLike,
|
FloatLike,
|
||||||
limit_param_value,
|
limit_param_value,
|
||||||
|
ScaleGrad,
|
||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
@ -1719,25 +1720,22 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
# a too-large gradient).
|
# a too-large gradient).
|
||||||
|
|
||||||
self.conv = nn.Sequential(
|
self.conv = nn.Sequential(
|
||||||
ScalarMultiply(0.1),
|
nn.Conv2d(
|
||||||
ScaledConv2d(
|
|
||||||
in_channels=1,
|
in_channels=1,
|
||||||
out_channels=layer1_channels,
|
out_channels=layer1_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
padding=(0, 1), # (time, freq)
|
padding=(0, 1), # (time, freq)
|
||||||
initial_scale=5.0,
|
|
||||||
),
|
),
|
||||||
ScalarMultiply(0.25),
|
ScaleGrad(0.1),
|
||||||
ActivationBalancer(layer1_channels,
|
ActivationBalancer(layer1_channels,
|
||||||
channel_dim=1),
|
channel_dim=1),
|
||||||
DoubleSwish(),
|
DoubleSwish(),
|
||||||
ScaledConv2d(
|
nn.Conv2d(
|
||||||
in_channels=layer1_channels,
|
in_channels=layer1_channels,
|
||||||
out_channels=layer2_channels,
|
out_channels=layer2_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=2,
|
stride=2,
|
||||||
padding=0,
|
padding=0,
|
||||||
initial_scale=5.0,
|
|
||||||
),
|
),
|
||||||
ActivationBalancer(layer2_channels,
|
ActivationBalancer(layer2_channels,
|
||||||
channel_dim=1),
|
channel_dim=1),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user