Also set the scale for the embedding to 0.025.

This commit is contained in:
Daniel Povey 2022-03-18 21:30:05 +08:00
parent 188eada7ac
commit 8cff994cd7

View File

@ -451,8 +451,9 @@ class ScaledEmbedding(nn.Module):
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=0.05)
nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log())
std = 0.025
nn.init.normal_(self.weight, std=std)
nn.init.constant_(self.scale, torch.tensor(1.0/std).log())
if self.padding_idx is not None:
with torch.no_grad():