diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 1fc46259b..34e056955 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -558,6 +558,32 @@ def ScaledConv1d(*args,
     return ans
 
 
+def ScaledConv2d(*args,
+                 initial_scale: float = 1.0,
+                 **kwargs) -> nn.Conv2d:
+    """
+    Behaves like a constructor of a modified version of nn.Conv2d
+    that gives an easy way to set the default initial parameter scale.
+
+    Args:
+        Accepts the standard args and kwargs that nn.Conv2d accepts,
+        e.g. in_channels, out_channels, kernel_size, bias=False.
+
+        initial_scale: you can override this if you want to increase
+           or decrease the initial magnitude of the module's output
+           (affects how the weight and bias are initialized).
+           Another option, if you want to do something like this, is
+           to re-initialize the parameters.
+    """
+    ans = nn.Conv2d(*args, **kwargs)
+    with torch.no_grad():
+        ans.weight[:] *= initial_scale
+        if ans.bias is not None:
+            torch.nn.init.uniform_(ans.bias,
+                                   -0.1 * initial_scale,
+                                   0.1 * initial_scale)
+    return ans
+
 
 class ActivationBalancer(torch.nn.Module):
     """
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 14eb2ca94..c34f465af 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -33,6 +33,7 @@ from scaling import (
     SwooshR,
     TanSwish,
     ScaledConv1d,
+    ScaledConv2d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
     LinearWithAuxLoss,
     Whiten,
@@ -1719,22 +1720,24 @@ class Conv2dSubsampling(nn.Module):
 
         self.conv = nn.Sequential(
             ScalarMultiply(0.1),
-            nn.Conv2d(
+            ScaledConv2d(
                 in_channels=1,
                 out_channels=layer1_channels,
                 kernel_size=3,
                 padding=(0, 1),  # (time, freq)
+                initial_scale=5.0,
             ),
             ScalarMultiply(0.25),
             ActivationBalancer(layer1_channels, channel_dim=1),
             DoubleSwish(),
-            nn.Conv2d(
+            ScaledConv2d(
                 in_channels=layer1_channels,
                 out_channels=layer2_channels,
                 kernel_size=3,
                 stride=2,
                 padding=0,
+                initial_scale=5.0,
            ),
             ActivationBalancer(layer2_channels, channel_dim=1),
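
For reference (not part of the patch), the sketch below reproduces the new ScaledConv2d helper from the scaling.py hunk and demonstrates what initial_scale=5.0 does in the Conv2dSubsampling hunk: the module behaves like a plain nn.Conv2d whose initial output magnitude is scaled up, presumably to offset the small ScalarMultiply factors around these convolutions. The channel counts and input shape used here are illustrative assumptions, not values taken from the patch.

# Minimal, self-contained sketch; the ScaledConv2d body is copied from the
# scaling.py hunk above, everything else is an illustrative assumption.
import torch
import torch.nn as nn


def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d:
    ans = nn.Conv2d(*args, **kwargs)
    with torch.no_grad():
        # Scale up the default weight initialization and re-draw the bias
        # in a correspondingly scaled range.
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(ans.bias,
                                   -0.1 * initial_scale,
                                   0.1 * initial_scale)
    return ans


plain = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)
scaled = ScaledConv2d(in_channels=1, out_channels=8, kernel_size=3,
                      initial_scale=5.0)  # value used in the zipformer.py hunk

x = torch.randn(2, 1, 19, 80)  # (batch, channels, time, freq) -- assumed shape
# At initialization the scaled module's output is roughly 5x larger in
# magnitude than the plain one's; gradients and training then proceed as
# for an ordinary nn.Conv2d.
print(plain(x).abs().mean().item(), scaled(x).abs().mean().item())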