diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py
index d07d2eaee..33ca51743 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/conformer.py
@@ -457,8 +457,8 @@ class RelPositionMultiheadAttention(nn.Module):
             return self.pos_bias_v * self.pos_bias_v_scale.exp()
 
     def _reset_parameters(self) -> None:
-        nn.init.uniform_(self.pos_bias_u, -0.1, 0.1)
-        nn.init.uniform_(self.pos_bias_v, -0.1, 0.1)
+        nn.init.uniform_(self.pos_bias_u, -0.2, 0.2)
+        nn.init.uniform_(self.pos_bias_v, -0.2, 0.2)
 
     def forward(
         self,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling.py
index 6556edd3a..f0a1ec0ca 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/scaling.py
@@ -179,7 +179,7 @@ class ScaledLinear(nn.Linear):
         with torch.no_grad():
             self.weight[:] *= initial_scale
             if self.bias is not None:
-                torch.nn.init.uniform_(self.bias, -0.1, 0.1)
+                torch.nn.init.uniform_(self.bias, -0.2, 0.2)
 
     def get_weight(self):  # not needed any more but kept for back compatibility
         return self.weight
@@ -201,7 +201,7 @@ class ScaledConv1d(nn.Conv1d):
         with torch.no_grad():
             self.weight[:] *= initial_scale
             if self.bias is not None:
-                torch.nn.init.uniform_(self.bias, -0.1, 0.1)
+                torch.nn.init.uniform_(self.bias, -0.2, 0.2)
 
     def get_weight(self):  # TODO: delete
         return self.weight
@@ -222,7 +222,7 @@ class ScaledConv2d(nn.Conv2d):
         with torch.no_grad():
             self.weight[:] *= initial_scale
             if self.bias is not None:
-                torch.nn.init.uniform_(self.bias, -0.1, 0.1)
+                torch.nn.init.uniform_(self.bias, -0.2, 0.2)
 
     def get_weight(self):
         return self.weight
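
The patch widens every uniform initialization in this recipe from U(-0.1, 0.1) to U(-0.2, 0.2): the pos_bias_u/pos_bias_v positional-bias parameters in RelPositionMultiheadAttention, and the bias of ScaledLinear, ScaledConv1d, and ScaledConv2d. Below is a minimal, self-contained sketch (not part of the patch) that makes the new bias range observable; TinyScaledLinear is a hypothetical stand-in that reproduces only the initial_scale weight scaling and the widened bias init, not the full icefall ScaledLinear class.

    # Hypothetical stand-in for icefall's ScaledLinear; only the
    # initial_scale weight scaling and the widened bias init from
    # this patch are reproduced here.
    import torch
    import torch.nn as nn


    class TinyScaledLinear(nn.Linear):
        def __init__(self, *args, initial_scale: float = 1.0, **kwargs):
            super().__init__(*args, **kwargs)
            with torch.no_grad():
                self.weight[:] *= initial_scale
                if self.bias is not None:
                    # The change in this patch: biases drawn from
                    # U(-0.2, 0.2) rather than U(-0.1, 0.1).
                    torch.nn.init.uniform_(self.bias, -0.2, 0.2)


    if __name__ == "__main__":
        torch.manual_seed(0)
        layer = TinyScaledLinear(80, 256, initial_scale=0.5)
        # Every bias now lands in the wider [-0.2, 0.2] interval.
        assert layer.bias.abs().max() <= 0.2
        print(layer.bias.min().item(), layer.bias.max().item())

The net effect, assuming nothing beyond what the diff shows, is simply to double the magnitude of these parameters at initialization; no claim about training dynamics is made here.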