Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-10 18:42:19 +00:00)
Modifying initialization from normal->uniform; add initial_scale when initializing

commit 0e9cad3f1f (parent 00be56c7a0)
@@ -441,15 +441,16 @@ class ScaledLinear(nn.Linear):
         self._reset_parameters()  # Overrides the reset_parameters in nn.Linear

     def _reset_parameters(self):
-        nn.init.normal_(self.weight, std=0.05)
+        std = 0.05
+        a = math.sqrt(3) * std
+        nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
-        fan_in = self.weight.shape[1]
+        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
         scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)

     def get_weight(self):
         return self.weight * (self.weight_scale * self.scale_speed).exp()
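A note on the normal -> uniform switch above: a uniform distribution on [-a, a] has standard deviation a / sqrt(3), so drawing from Uniform(-sqrt(3) * 0.05, sqrt(3) * 0.05) keeps the same standard deviation as the old nn.init.normal_(std=0.05). A minimal sanity check, assuming only PyTorch (the tensor shape is arbitrary):

import math
import torch

std = 0.05
a = math.sqrt(3) * std  # half-width of the uniform range, as in the diff

w = torch.empty(4000, 1000)
torch.nn.init.uniform_(w, -a, a)
print(w.std().item())  # ~= 0.05, matching the old normal_(std=0.05) init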
@@ -476,7 +477,9 @@ class ScaledConv1d(nn.Conv1d):
         self._reset_parameters()  # Overrides the reset_parameters in base class

     def _reset_parameters(self):
-        nn.init.normal_(self.weight, std=0.05)
+        std = 0.05
+        a = math.sqrt(3) * std
+        nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
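With this commit the fan_in expression is identical across ScaledLinear, ScaledConv1d and ScaledConv2d. For a plain linear layer the extra self.weight[0][0].numel() factor is 1 (the weight is 2-D), while for a convolution it multiplies in the number of kernel elements, which is the usual fan-in convention. A small illustration with arbitrary layer sizes:

import torch.nn as nn

lin = nn.Linear(256, 512)                   # weight shape: (512, 256)
conv = nn.Conv1d(256, 512, kernel_size=3)   # weight shape: (512, 256, 3)

fan_in_lin = lin.weight.shape[1] * lin.weight[0][0].numel()     # 256 * 1 = 256
fan_in_conv = conv.weight.shape[1] * conv.weight[0][0].numel()  # 256 * 3 = 768
print(fan_in_lin, fan_in_conv)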
@@ -516,10 +519,12 @@ class ScaledConv2d(nn.Conv2d):
         self._reset_parameters()  # Overrides the reset_parameters in base class

     def _reset_parameters(self):
-        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        nn.init.normal_(self.weight, std=0.05)
+        std = 0.05
+        a = math.sqrt(3) * std
+        nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
+        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
         scale = fan_in ** -0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += (torch.tensor(scale / 0.05).log() / self.scale_speed)
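The unchanged weight_scale update is what turns the raw std=0.05 draw into an effective 1/sqrt(fan_in) initialization: get_weight() multiplies the stored weight by exp(weight_scale * scale_speed), so adding log(scale / 0.05) / scale_speed to weight_scale scales the effective weight by scale / 0.05. A sketch of the arithmetic, assuming weight_scale starts at a point where that exponential factor is 1 (the scale_speed value here is only illustrative):

import math

std = 0.05
fan_in = 768                 # e.g. in_channels * kernel_size for a Conv1d
scale = fan_in ** -0.5       # target effective std, 1/sqrt(fan_in)
scale_speed = 10.0           # illustrative; the real default is set in __init__

delta = math.log(scale / std) / scale_speed   # amount added to weight_scale
factor = math.exp(delta * scale_speed)        # factor applied in get_weight()
print(factor, std * factor)                   # factor == scale/std, std*factor == scale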
@@ -162,7 +162,7 @@ class ConformerEncoderLayer(nn.Module):
             DerivBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )

         self.feed_forward_macaron = nn.Sequential(
@@ -170,7 +170,7 @@ class ConformerEncoderLayer(nn.Module):
             DerivBalancer(channel_dim=-1),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model),
+            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
         )

         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
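The diff shown here does not include the ScaledLinear constructor change, so how the new initial_scale keyword is consumed is not visible. A plausible reading, consistent with the log-space parameterization above, is that it is folded into the same weight_scale correction so the layer's effective output starts 4x smaller when initial_scale=0.25. The helper below is a hypothetical sketch of that arithmetic, not the actual icefall implementation (the function name, signature and scale_speed value are assumptions):

import math

def weight_scale_delta(std: float, fan_in: int, scale_speed: float,
                       initial_scale: float = 1.0) -> float:
    # Hypothetical: fold initial_scale into the same log-space correction
    # that already maps std=0.05 onto 1/sqrt(fan_in).
    scale = fan_in ** -0.5
    return math.log(initial_scale * scale / std) / scale_speed

# With initial_scale=0.25 the layer's effective output starts 4x smaller,
# e.g. for the feed-forward output projections above.
print(weight_scale_delta(0.05, 2048, 10.0, initial_scale=0.25))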
@@ -423,7 +423,7 @@ class RelPositionMultiheadAttention(nn.Module):
         ), "embed_dim must be divisible by num_heads"

         self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
-        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True)
+        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)

         # linear transformation for positional encoding.
         self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
@@ -885,6 +885,7 @@ class ConvolutionModule(nn.Module):
             stride=1,
             padding=0,
             bias=bias,
+            initial_scale=0.25
         )

     def forward(self, x: Tensor) -> Tensor:
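Taken together, the initial_scale=0.25 call sites in this commit are the final ScaledLinear of both feed-forward blocks, the attention out_proj, and a convolution inside ConvolutionModule, i.e. the projections whose outputs feed each Conformer layer's residual connections, which suggests the intent is to damp those residual contributions at the start of training.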