Initialize Conv2dSubsampling with scale.

This commit is contained in:
Daniel Povey 2022-12-05 17:31:56 +08:00
parent 7999dd0dbe
commit b93cf0676a
2 changed files with 31 additions and 2 deletions

View File

@ -558,6 +558,32 @@ def ScaledConv1d(*args,
return ans return ans
def ScaledConv2d(*args,
                 initial_scale: float = 1.0,
                 **kwargs) -> nn.Conv2d:
    """
    Construct a plain nn.Conv2d whose freshly initialized parameters have
    been rescaled, giving an easy way to set the initial parameter scale.

    Args:
        *args, **kwargs: forwarded unchanged to the nn.Conv2d constructor,
            e.g. in_channels, out_channels, kernel_size, stride, padding,
            bias=False.
        initial_scale: factor applied to the default-initialized weight.
            If the module has a bias, the bias is re-drawn uniformly from
            [-0.1 * initial_scale, 0.1 * initial_scale].  Use a value other
            than 1.0 to increase or decrease the initial magnitude of the
            module's output; re-initializing the parameters yourself after
            construction is another way to achieve this.

    Returns:
        The nn.Conv2d instance, with its parameters modified in place.
    """
    conv = nn.Conv2d(*args, **kwargs)
    bias_limit = 0.1 * initial_scale
    # Parameter surgery must not be tracked by autograd.
    with torch.no_grad():
        conv.weight.mul_(initial_scale)
        if conv.bias is not None:
            torch.nn.init.uniform_(conv.bias, -bias_limit, bias_limit)
    return conv
class ActivationBalancer(torch.nn.Module): class ActivationBalancer(torch.nn.Module):
""" """

View File

@ -33,6 +33,7 @@ from scaling import (
SwooshR, SwooshR,
TanSwish, TanSwish,
ScaledConv1d, ScaledConv1d,
ScaledConv2d,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values. ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
LinearWithAuxLoss, LinearWithAuxLoss,
Whiten, Whiten,
@ -1719,22 +1720,24 @@ class Conv2dSubsampling(nn.Module):
self.conv = nn.Sequential( self.conv = nn.Sequential(
ScalarMultiply(0.1), ScalarMultiply(0.1),
nn.Conv2d( ScaledConv2d(
in_channels=1, in_channels=1,
out_channels=layer1_channels, out_channels=layer1_channels,
kernel_size=3, kernel_size=3,
padding=(0, 1), # (time, freq) padding=(0, 1), # (time, freq)
initial_scale=5.0,
), ),
ScalarMultiply(0.25), ScalarMultiply(0.25),
ActivationBalancer(layer1_channels, ActivationBalancer(layer1_channels,
channel_dim=1), channel_dim=1),
DoubleSwish(), DoubleSwish(),
nn.Conv2d( ScaledConv2d(
in_channels=layer1_channels, in_channels=layer1_channels,
out_channels=layer2_channels, out_channels=layer2_channels,
kernel_size=3, kernel_size=3,
stride=2, stride=2,
padding=0, padding=0,
initial_scale=5.0,
), ),
ActivationBalancer(layer2_channels, ActivationBalancer(layer2_channels,
channel_dim=1), channel_dim=1),