mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Change 0.025,0.05 to 0.01 in initializations
This commit is contained in:
parent
05e30d0c46
commit
11a04c50ae
@@ -440,8 +440,8 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
return self.pos_bias_v * self.pos_bias_v_scale.exp()
|
||||
|
||||
def _reset_parameters(self) -> None:
|
||||
-nn.init.normal_(self.pos_bias_u, std=0.05)
|
||||
-nn.init.normal_(self.pos_bias_v, std=0.05)
|
||||
+nn.init.normal_(self.pos_bias_u, std=0.01)
|
||||
+nn.init.normal_(self.pos_bias_v, std=0.01)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@@ -153,7 +153,7 @@ class ScaledLinear(nn.Linear):
|
||||
self._reset_parameters() # Overrides the reset_parameters in nn.Linear
|
||||
|
||||
def _reset_parameters(self):
|
||||
-std = 0.025
|
||||
+std = 0.01
|
||||
a = (3 ** 0.5) * std
|
||||
nn.init.uniform_(self.weight, -a, a)
|
||||
if self.bias is not None:
|
||||
@@ -188,7 +188,7 @@ class ScaledConv1d(nn.Conv1d):
|
||||
self._reset_parameters() # Overrides the reset_parameters in base class
|
||||
|
||||
def _reset_parameters(self):
|
||||
-std = 0.025
|
||||
+std = 0.01
|
||||
a = (3 ** 0.5) * std
|
||||
nn.init.uniform_(self.weight, -a, a)
|
||||
if self.bias is not None:
|
||||
@@ -229,7 +229,7 @@ class ScaledConv2d(nn.Conv2d):
|
||||
self._reset_parameters() # Overrides the reset_parameters in base class
|
||||
|
||||
def _reset_parameters(self):
|
||||
-std = 0.025
|
||||
+std = 0.01
|
||||
a = (3 ** 0.5) * std
|
||||
nn.init.uniform_(self.weight, -a, a)
|
||||
if self.bias is not None:
|
||||
@@ -451,7 +451,7 @@ class ScaledEmbedding(nn.Module):
|
||||
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
-std = 0.025
|
||||
+std = 0.01
|
||||
nn.init.normal_(self.weight, std=std)
|
||||
nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user