From 0e9cad3f1f62abda43c6b218917525142c32b3d3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 16 Mar 2022 14:42:53 +0800 Subject: [PATCH] Modifying initialization from normal->uniform; add initial_scale when initializing --- .../ASR/conformer_ctc/subsampling.py | 17 +++++++++++------ .../ASR/transducer_stateless/conformer.py | 7 ++++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 5e44c5b29..6cc90c8a1 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -441,15 +441,16 @@ class ScaledLinear(nn.Linear): self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): - nn.init.normal_(self.weight, std=0.05) + std = 0.05 + a = math.sqrt(3) * std + nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] + fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) - def get_weight(self): return self.weight * (self.weight_scale * self.scale_speed).exp() @@ -476,7 +477,9 @@ class ScaledConv1d(nn.Conv1d): self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - nn.init.normal_(self.weight, std=0.05) + std = 0.05 + a = math.sqrt(3) * std + nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() @@ -516,10 +519,12 @@ class ScaledConv2d(nn.Conv2d): self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - nn.init.normal_(self.weight, std=0.05) + std = 0.05 + a = math.sqrt(3) * std + nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) + fan_in = self.weight.shape[1] * self.weight[0][0].numel() scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index cafc04ed1..0832d9385 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module): DerivBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) self.feed_forward_macaron = nn.Sequential( @@ -170,7 +170,7 @@ class ConformerEncoderLayer(nn.Module): DerivBalancer(channel_dim=-1), DoubleSwish(), nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) @@ -423,7 +423,7 @@ class RelPositionMultiheadAttention(nn.Module): ), "embed_dim must be divisible by num_heads" self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True) + self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25) # linear transformation for positional encoding. self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) @@ -885,6 +885,7 @@ class ConvolutionModule(nn.Module): stride=1, padding=0, bias=bias, + initial_scale=0.25 ) def forward(self, x: Tensor) -> Tensor: