bug fix re sqrt

This commit is contained in:
Daniel Povey 2022-03-16 14:55:17 +08:00
parent 0e9cad3f1f
commit 6561743d7b

View File

@ -442,7 +442,7 @@ class ScaledLinear(nn.Linear):
def _reset_parameters(self):
std = 0.05
a = math.sqrt(3) * std
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
@ -478,7 +478,7 @@ class ScaledConv1d(nn.Conv1d):
def _reset_parameters(self):
std = 0.05
a = math.sqrt(3) * std
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
@ -520,7 +520,7 @@ class ScaledConv2d(nn.Conv2d):
def _reset_parameters(self):
std = 0.05
a = math.sqrt(3) * std
a = (3 ** 0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)