Mirror of https://github.com/k2-fsa/icefall.git

Commit 0e9cad3f1f (parent 00be56c7a0)

Modifying initialization from normal->uniform; add initial_scale when initializing

@@ -441,15 +441,16 @@ class ScaledLinear(nn.Linear):
         self._reset_parameters()  # Overrides the reset_parameters in nn.Linear

     def _reset_parameters(self):
-        nn.init.normal_(self.weight, std=0.05)
+        std = 0.05
+        a = math.sqrt(3) * std
+        nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
-        fan_in = self.weight.shape[1]
+        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
         scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)

     def get_weight(self):
         return self.weight * (self.weight_scale * self.scale_speed).exp()

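Note on the hunk above (illustration, not part of the commit): a uniform distribution on [-a, a] has variance a**2 / 3, so choosing a = sqrt(3) * std keeps the weights' standard deviation at 0.05, matching the old nn.init.normal_(std=0.05); only the shape of the distribution changes. A minimal check:

import math
import torch

std = 0.05
a = math.sqrt(3) * std           # Var(U(-a, a)) = a**2 / 3 = std**2

w = torch.empty(10000, 512)
torch.nn.init.uniform_(w, -a, a)
print(w.std())                   # ~0.05, same scale as normal_(std=0.05)
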
@@ -476,7 +477,9 @@ class ScaledConv1d(nn.Conv1d):
         self._reset_parameters()  # Overrides the reset_parameters in base class

     def _reset_parameters(self):
-        nn.init.normal_(self.weight, std=0.05)
+        std = 0.05
+        a = math.sqrt(3) * std
+        nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()

@@ -516,10 +519,12 @@ class ScaledConv2d(nn.Conv2d):
         self._reset_parameters()  # Overrides the reset_parameters in base class

     def _reset_parameters(self):
-        nn.init.normal_(self.weight, std=0.05)
+        std = 0.05
+        a = math.sqrt(3) * std
+        nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
         scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)

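Note on fan_in (illustration, not part of the commit): self.weight.shape[1] * self.weight[0][0].numel() evaluates to in_features for a Linear weight (weight[0][0] is a scalar) and to in_channels * prod(kernel_size) for a ConvNd weight, so the same expression now works in all three classes. The weight_scale adjustment that follows multiplies the effective weight returned by get_weight() by scale / 0.05, so the effective weights start with a standard deviation of roughly fan_in ** -0.5. A small sketch with arbitrary sizes:

import torch.nn as nn

lin = nn.Linear(256, 512)
conv = nn.Conv1d(256, 512, kernel_size=31)

for w in (lin.weight, conv.weight):
    fan_in = w.shape[1] * w[0][0].numel()
    print(tuple(w.shape), fan_in, fan_in ** -0.5)
# (512, 256)      256   0.0625  <- Linear: w[0][0] is a scalar, numel() == 1
# (512, 256, 31)  7936  0.0112  <- Conv1d: in_channels * kernel_size
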
@@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module):
             DerivBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )

         self.feed_forward_macaron = nn.Sequential(

@@ -170,7 +170,7 @@ class ConformerEncoderLayer(nn.Module):
             DerivBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )

         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

@@ -423,7 +423,7 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"

         self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
-        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True)
+        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)

         # linear transformation for positional encoding.
         self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)

@@ -885,6 +885,7 @@ class ConvolutionModule(nn.Module):
             stride=1,
             padding=0,
             bias=bias,
+            initial_scale=0.25
         )

     def forward(self, x: Tensor) -> Tensor:

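Note on initial_scale (assumption; the constructor change that defines the argument is not in these hunks): the call sites above pass initial_scale=0.25 to the output projection of each feed-forward block, the attention out_proj, and the final pointwise conv of the convolution module, presumably so that each residual branch starts with its effective output weights scaled down by 4x. One plausible reading, consistent with the weight_scale arithmetic in the hunks above, folds log(initial_scale) into weight_scale at construction time, for example:

import torch

def apply_initial_scale(weight_scale: torch.nn.Parameter,
                        scale_speed: float,
                        initial_scale: float = 0.25) -> None:
    # Hypothetical helper (not from the commit): fold the extra factor into
    # the learned log-scale, so get_weight() = weight *
    # (weight_scale * scale_speed).exp() starts exactly `initial_scale`
    # times smaller than it otherwise would.
    with torch.no_grad():
        weight_scale += torch.tensor(initial_scale).log() / scale_speed

Since weight_scale remains a learnable parameter, the 0.25 factor only damps the residual branches at the start of training; the scale is free to grow afterwards.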