Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
remove all lr_scales, set layer3_channels=128, change the position of feed_forward1
This commit is contained in:
parent 2cd1933873
commit 9291a39f58
@@ -218,9 +218,6 @@ class Zipformer2(EncoderInterface):
                 downsample=downsampling_factor[i],
                 dropout=dropout,
             )
-            # we are adding a new attribute here.
-            # this will be interpreted by get_named_parameter_groups_with_lrs().
-            encoder.lr_scale = downsampling_factor[i] ** -0.33
 
             encoders.append(encoder)
 
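The comment removed in the hunk above referred to get_named_parameter_groups_with_lrs(), which (as I understand it) turns per-module lr_scale attributes into optimizer parameter groups with scaled learning rates; with the attribute gone, every encoder stack falls back to the base learning rate. The sketch below only illustrates that general mechanism, not icefall's implementation; get_param_groups, base_lr and the toy model are hypothetical names used for the example.

import torch
import torch.nn as nn

def get_param_groups(model: nn.Module, base_lr: float):
    """Hypothetical sketch (not icefall's get_named_parameter_groups_with_lrs):
    give each module's directly-owned parameters a learning rate of
    base_lr * lr_scale, where lr_scale defaults to 1.0 if the module
    does not define it."""
    groups = []
    for module in model.modules():
        own_params = list(module.parameters(recurse=False))
        if own_params:
            scale = getattr(module, "lr_scale", 1.0)
            groups.append({"params": own_params, "lr": base_lr * scale})
    return groups

# Usage sketch: scale down the learning rate of one submodule, as the removed
# encoder.lr_scale attribute did for each downsampled encoder stack.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[1].lr_scale = 0.5
optim = torch.optim.Adam(get_param_groups(model, base_lr=1e-3))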
@@ -713,6 +710,8 @@ class Zipformer2EncoderLayer(nn.Module):
             key_padding_mask=src_key_padding_mask,
         )
 
+        src = src + self.feed_forward1(src)
+
         self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
 
         if True:
@@ -733,8 +732,6 @@ class Zipformer2EncoderLayer(nn.Module):
 
         src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
 
-        src = src + self.feed_forward1(src)
-
         self_attn = self.self_attn1(
             src, attn_weights)
 
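Taken together, the two hunks above move `src = src + self.feed_forward1(src)` from after the self-attention branch to before it. Below is a toy, self-contained illustration of that reordering only; it is not the real Zipformer2EncoderLayer (whose self_attn1 takes precomputed attention weights and which has several other branches).

import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    """Minimal stand-in for Zipformer2EncoderLayer, showing only the ordering
    change: feed_forward1 is applied before the self-attention branch."""

    def __init__(self, d: int = 16):
        super().__init__()
        self.feed_forward1 = nn.Linear(d, d)
        self.self_attn1 = nn.MultiheadAttention(d, num_heads=2, batch_first=True)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        src = src + self.feed_forward1(src)      # new position: before self-attention
        attn, _ = self.self_attn1(src, src, src)
        src = src + attn
        # old position: the feed_forward1 residual used to be applied here
        return src

print(ToyLayer()(torch.randn(2, 5, 16)).shape)   # torch.Size([2, 5, 16])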
@@ -1200,7 +1197,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                              (4000.0, 0.0))
     ) -> None:
         super().__init__()
-        self.lr_scale = 0.9
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.query_head_dim = query_head_dim
@@ -1518,8 +1514,6 @@ class NonlinAttention(nn.Module):
     ) -> None:
         super().__init__()
 
-        self.lr_scale = 0.95
-
         self.hidden_channels = hidden_channels
 
         self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
@@ -1633,7 +1627,6 @@ class ConvolutionModule(nn.Module):
         )
         # the gradients on in_proj are a little noisy, likely to do with the
         # sigmoid in glu.
-        self.in_proj.lr_scale = 0.9
 
         # after in_proj we put x through a gated linear unit (nn.functional.glu).
         # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
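The surviving comments in this hunk explain that in_proj feeds a gated linear unit, whose sigmoid is the suspected source of the noisy in_proj gradients the removed lr_scale was compensating for. As a minimal, self-contained reminder of what nn.functional.glu does (sizes here are illustrative, not the module's actual dimensions):

import torch
import torch.nn as nn
import torch.nn.functional as F

channels, hidden = 8, 16                      # illustrative sizes, not the real config
in_proj = nn.Linear(channels, 2 * hidden)     # produces 2*hidden features for the GLU

x = torch.randn(4, channels)
y = F.glu(in_proj(x), dim=-1)                 # splits into (a, b), returns a * sigmoid(b)
print(y.shape)                                # torch.Size([4, 16])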
@@ -1862,7 +1855,7 @@ class Conv2dSubsampling(nn.Module):
         out_channels: int,
         layer1_channels: int = 8,
         layer2_channels: int = 32,
-        layer3_channels: int = 64,
+        layer3_channels: int = 128,
         dropout: FloatLike = 0.1,
     ) -> None:
         """