Mirror of https://github.com/k2-fsa/icefall.git
Initialize Conv2dSubsampling with scale.

commit b93cf0676a (parent 7999dd0dbe)
@@ -558,6 +558,32 @@ def ScaledConv1d(*args,
     return ans
 
 
+def ScaledConv2d(*args,
+                 initial_scale: float = 1.0,
+                 **kwargs) -> nn.Conv2d:
+    """
+    Behaves like a constructor of a modified version of nn.Conv2d
+    that gives an easy way to set the default initial parameter scale.
+
+    Args:
+        Accepts the standard args and kwargs that nn.Conv2d accepts,
+        e.g. in_channels, out_channels, bias=False.
+
+        initial_scale: you can override this if you want to increase
+            or decrease the initial magnitude of the module's output
+            (implemented by scaling the initial weight and bias values).
+            Another option, if you want to do something like this, is
+            to re-initialize the parameters.
+    """
+    ans = nn.Conv2d(*args, **kwargs)
+    with torch.no_grad():
+        ans.weight[:] *= initial_scale
+        if ans.bias is not None:
+            torch.nn.init.uniform_(ans.bias,
+                                   -0.1 * initial_scale,
+                                   0.1 * initial_scale)
+    return ans
+
+
 class ActivationBalancer(torch.nn.Module):
     """
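For reference, the helper's effect at construction time can be checked in isolation. This is a usage sketch, not part of the commit: it assumes the recipe's scaling.py is importable, and the channel count and kernel size below are made up for illustration.

    import torch
    import torch.nn as nn

    from scaling import ScaledConv2d  # the helper added in the hunk above

    torch.manual_seed(0)
    base = nn.Conv2d(1, 8, kernel_size=3)  # stock PyTorch initialization
    torch.manual_seed(0)
    scaled = ScaledConv2d(1, 8, kernel_size=3, initial_scale=5.0)

    # With the same seed the underlying nn.Conv2d draws identical weights,
    # so the scaled module's weight is the default weight times 5.0; only
    # the bias is re-drawn, here from uniform(-0.5, 0.5).
    print(torch.allclose(scaled.weight, base.weight * 5.0))  # expected: True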
@@ -33,6 +33,7 @@ from scaling import (
     SwooshR,
     TanSwish,
     ScaledConv1d,
+    ScaledConv2d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
     LinearWithAuxLoss,
     Whiten,
@@ -1719,22 +1720,24 @@ class Conv2dSubsampling(nn.Module):
 
         self.conv = nn.Sequential(
             ScalarMultiply(0.1),
-            nn.Conv2d(
+            ScaledConv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
                 kernel_size=3,
                 padding=(0, 1),  # (time, freq)
+                initial_scale=5.0,
             ),
             ScalarMultiply(0.25),
             ActivationBalancer(layer1_channels,
                                channel_dim=1),
             DoubleSwish(),
-            nn.Conv2d(
+            ScaledConv2d(
                 in_channels=layer1_channels,
                 out_channels=layer2_channels,
                 kernel_size=3,
                 stride=2,
                 padding=0,
+                initial_scale=5.0,
             ),
             ActivationBalancer(layer2_channels,
                                channel_dim=1),
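At initialization this change makes the two front-end convolutions start out with weights about five times larger than the PyTorch default, and hence with correspondingly larger output activations. A rough standalone check follows; it is a sketch only, with out_channels=8 standing in for layer1_channels and an invented input shape:

    import torch
    import torch.nn as nn

    from scaling import ScaledConv2d

    torch.manual_seed(0)
    old_conv = nn.Conv2d(in_channels=1, out_channels=8,
                         kernel_size=3, padding=(0, 1))
    torch.manual_seed(0)
    new_conv = ScaledConv2d(in_channels=1, out_channels=8,
                            kernel_size=3, padding=(0, 1),
                            initial_scale=5.0)

    # (batch, channel, time, freq); the shape is illustrative only.
    x = torch.randn(2, 1, 20, 80)
    with torch.no_grad():
        ratio = new_conv(x).abs().mean() / old_conv(x).abs().mean()
    print(ratio)  # roughly 5: the first conv's initial outputs are ~5x larger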