Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Remove small_conv_module and make nonlin_attention_module slightly wider

Parent: 80b2c751e3
Commit: 659ca97001
@@ -439,7 +439,6 @@ class ZipformerEncoderLayer(nn.Module):
             # to work correctly.
             layer_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.05), default=0),
             attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
-            small_conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.2), (16000, 0.1), default=0),
             conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
             ff2_skip_rate: FloatLike = 0.01,
@@ -457,10 +456,6 @@ class ZipformerEncoderLayer(nn.Module):
         # contributing too much early on.
         self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
 
-        # skip rate for small_conv_module; it is fairly high and remains nonzero
-        # because we don't want this submodule to contribute too much.
-        self.small_conv_skip_rate = copy.deepcopy(small_conv_skip_rate)
-
         # ff2_skip_rate is to prevent the ff2 module from having output that's too big
         # compared to its residual.
         self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
@@ -489,9 +484,7 @@ class ZipformerEncoderLayer(nn.Module):
                                                   dropout)
 
         self.nonlin_attention = NonlinAttention(embed_dim,
-                                                hidden_channels=embed_dim // 2)
+                                                hidden_channels=3 * embed_dim // 4)
 
-        self.small_conv_module = SmallConvolutionModule(embed_dim)
-
         self.conv_module = ConvolutionModule(embed_dim,
                                              cnn_module_kernel)
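
Note on the hunk above: widening NonlinAttention's hidden_channels from embed_dim // 2 to 3 * embed_dim // 4 makes the module's internal width 1.5x larger, which is what the commit title means by "slightly wider". A minimal sketch of the arithmetic, assuming an example encoder dimension of embed_dim = 384 (the actual value is not shown in this diff):

    embed_dim = 384                    # assumed example value, not taken from this diff
    old_hidden = embed_dim // 2        # 192 hidden channels before this commit
    new_hidden = 3 * embed_dim // 4    # 288 hidden channels after this commit
    print(old_hidden, new_hidden)      # -> 192 288
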
@@ -637,9 +630,6 @@ class ZipformerEncoderLayer(nn.Module):
                                                  selected_attn_weights[0:1]))
 
 
-        if torch.jit.is_scripting() or random.random() >= float(self.small_conv_skip_rate):
-            src = src + self.small_conv_module(src, src_key_padding_mask=src_key_padding_mask)
-
         src = src + self.feed_forward1(src)
 
         ## pooling module
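
The lines removed in the last hunk follow the same skip pattern the layer uses for its other submodules: during training the submodule's residual contribution is randomly dropped with a probability given by a ScheduledFloat (for small_conv_skip_rate, 0.5 early on, decaying to 0.2 around 4000 and 0.1 from 16000, presumably indexed by the training batch count), while under TorchScript the module is always applied. Below is a minimal, self-contained sketch of that pattern, assuming a constant skip probability instead of the real ScheduledFloat and omitting src_key_padding_mask; maybe_apply is a hypothetical helper, not an icefall function.

    import random
    import torch

    def maybe_apply(src: torch.Tensor, module: torch.nn.Module, skip_rate: float) -> torch.Tensor:
        # Always apply the module when scripted (inference); during training,
        # skip it with probability `skip_rate` so the submodule cannot
        # contribute too much early on.
        if torch.jit.is_scripting() or random.random() >= float(skip_rate):
            return src + module(src)
        return src

    # Example: apply a pointwise module only about half the time early in training.
    ff = torch.nn.Linear(16, 16)
    x = torch.randn(10, 16)
    y = maybe_apply(x, ff, skip_rate=0.5)
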