diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index fe19ed8aa..cd1c40294 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -421,8 +421,7 @@ class ZipformerEncoderLayer(nn.Module):
                                              cnn_module_kernel)
 
-        self.attention_squeeze1 = AttentionSqueeze(embed_dim)
-        self.attention_squeeze2 = AttentionSqueeze(embed_dim)
+        self.attention_squeeze = AttentionSqueeze(embed_dim)
 
         self.norm_final = BasicNorm(embed_dim)
 
 
@@ -510,7 +509,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze1(src, first_attn_weights[1:2])
+            src = src + self.attention_squeeze(src, first_attn_weights[1:2])
 
         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn(
@@ -521,9 +520,6 @@ class ZipformerEncoderLayer(nn.Module):
 
         src = src + self.feed_forward2(src)
 
-        # pooling module
-        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze2(src, first_attn_weights[2:3])
 
         src = self.norm_final(self.balancer(src))
 
@@ -1432,11 +1428,11 @@ class NonlinAttentionModule(nn.Module):
     ) -> None:
         super().__init__()
 
-        self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
+        self.in_proj = nn.Linear(channels, channels, bias=True)
 
         # balancer that goes before the sigmoid.
         self.balancer = ActivationBalancer(
-            channels, channel_dim=-1,
+            channels // 2, channel_dim=-1,
             min_positive=0.05, max_positive=1.0,
             min_abs=0.2, max_abs=ScheduledFloat((0.0, 2.0),
                                                 (4000.0, 10.0),
@@ -1445,7 +1441,7 @@ class NonlinAttentionModule(nn.Module):
         self.sigmoid = nn.Sigmoid()
         self.activation = Identity()  # for diagnostics.
 
-        self.out_proj = ScaledLinear(channels, channels,
+        self.out_proj = ScaledLinear(channels // 2, channels,
                                      bias=True, initial_scale=0.05)
 
 
@@ -1532,7 +1528,10 @@ class ConvolutionModule(nn.Module):
         # the correct range.
         self.deriv_balancer1 = ActivationBalancer(
             2 * channels, channel_dim=-1,
-            max_abs=10.0, min_positive=0.05, max_positive=1.0
+            max_abs=ScheduledFloat((0.0, 2.0),
+                                   (4000.0, 10.0),
+                                   default=1.0),
+            min_positive=0.05, max_positive=1.0
         )
 
         self.pre_sigmoid = Identity()  # before sigmoid; for diagnostics.
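
Note (reading aid, not part of the patch): the recurring change above replaces a fixed max_abs=10.0 limit with ScheduledFloat((0.0, 2.0), (4000.0, 10.0), default=1.0). The sketch below is a hypothetical illustration of the piecewise-linear batch-count schedule those pairs are assumed to encode (the actual ScheduledFloat lives in icefall's scaling.py): the limit starts tight at 2.0 and relaxes to 10.0 by batch 4000, with `default` assumed to apply when no batch count is available, e.g. under torch.jit scripting.

# Hypothetical sketch of the assumed ScheduledFloat semantics; not the icefall implementation.
def scheduled_max_abs(batch_count,
                      points=((0.0, 2.0), (4000.0, 10.0)),
                      default=1.0):
    """Piecewise-linear interpolation of a constraint value over training batches."""
    if batch_count is None:
        return default  # assumed fallback when no batch count is set
    (x0, y0), (x1, y1) = points
    if batch_count <= x0:
        return y0
    if batch_count >= x1:
        return y1
    # linear interpolation between the two (batch, value) anchor points
    return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)

assert scheduled_max_abs(0) == 2.0
assert scheduled_max_abs(2000) == 6.0
assert scheduled_max_abs(10000) == 10.0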