mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Change initial bias scales from 0.1 to 0.2
This commit is contained in:
parent
435b073979
commit
5d57dd3930
@ -457,8 +457,8 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
return self.pos_bias_v * self.pos_bias_v_scale.exp()
|
return self.pos_bias_v * self.pos_bias_v_scale.exp()
|
||||||
|
|
||||||
def _reset_parameters(self) -> None:
|
def _reset_parameters(self) -> None:
|
||||||
nn.init.uniform_(self.pos_bias_u, -0.1, 0.1)
|
nn.init.uniform_(self.pos_bias_u, -0.2, 0.2)
|
||||||
nn.init.uniform_(self.pos_bias_v, -0.1, 0.1)
|
nn.init.uniform_(self.pos_bias_v, -0.2, 0.2)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -179,7 +179,7 @@ class ScaledLinear(nn.Linear):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight[:] *= initial_scale
|
self.weight[:] *= initial_scale
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
torch.nn.init.uniform_(self.bias, -0.1, 0.1)
|
torch.nn.init.uniform_(self.bias, -0.2, 0.2)
|
||||||
|
|
||||||
def get_weight(self): # not needed any more but kept for back compatibility
|
def get_weight(self): # not needed any more but kept for back compatibility
|
||||||
return self.weight
|
return self.weight
|
||||||
@ -201,7 +201,7 @@ class ScaledConv1d(nn.Conv1d):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight[:] *= initial_scale
|
self.weight[:] *= initial_scale
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
torch.nn.init.uniform_(self.bias, -0.1, 0.1)
|
torch.nn.init.uniform_(self.bias, -0.2, 0.2)
|
||||||
|
|
||||||
def get_weight(self): # TODO: delete
|
def get_weight(self): # TODO: delete
|
||||||
return self.weight
|
return self.weight
|
||||||
@ -222,7 +222,7 @@ class ScaledConv2d(nn.Conv2d):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight[:] *= initial_scale
|
self.weight[:] *= initial_scale
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
torch.nn.init.uniform_(self.bias, -0.1, 0.1)
|
torch.nn.init.uniform_(self.bias, -0.2, 0.2)
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self):
|
||||||
return self.weight
|
return self.weight
|
||||||
|
Loading…
x
Reference in New Issue
Block a user