diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index e22db4d34..5183f2867 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -440,7 +440,7 @@ class ZipformerEncoderLayer(nn.Module):
                                                 cnn_module_kernel)
 
-        self.attention_squeeze = AttentionSqueeze(embed_dim)
+        self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
         self.norm_final = BasicNorm(embed_dim)
@@ -1323,11 +1323,12 @@ class AttentionSqueeze(nn.Module):
     """
     def __init__(self,
                  embed_dim: int,
+                 hidden_dim: int,
                  bottleneck_dim: int = 16):
         super().__init__()
         self.bottleneck_dim = bottleneck_dim
 
-        self.in_proj = LinearWithAuxLoss(embed_dim, embed_dim,
+        self.in_proj = LinearWithAuxLoss(embed_dim, hidden_dim,
                                          bias=False,
                                          aux_grad_scale=_aux_grad_scale(),
                                          prob=_aux_grad_prob_in())
@@ -1355,13 +1356,13 @@ class AttentionSqueeze(nn.Module):
         # Make them run with very low probability, since only a small
         # application of these balancers should be enough to stop such "drift".
         self.scale_balancer = ActivationBalancer(
-            embed_dim, channel_dim=-1,
+            hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
             min_prob=0.05,
         )
         self.activation_balancer = ActivationBalancer(
-            embed_dim, channel_dim=-1,
+            hidden_dim, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
             min_prob=0.05,
@@ -1372,10 +1373,11 @@ class AttentionSqueeze(nn.Module):
                                            grad_scale=0.01)
 
-        self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
+        self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, hidden_dim)
 
-        self.out_proj = LinearWithAuxLoss(embed_dim, embed_dim,
-                                          aux_grad_scale=_aux_grad_scale(), prob=_aux_grad_prob_out(),
+        self.out_proj = LinearWithAuxLoss(hidden_dim, embed_dim,
+                                          aux_grad_scale=_aux_grad_scale(),
+                                          prob=_aux_grad_prob_out(),
                                           bias=False, initial_scale=0.05)
 
     def forward(self,
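Taken together, these hunks thread a new `hidden_dim` through `AttentionSqueeze`, turning its projections into a funnel: `in_proj` now maps `embed_dim -> hidden_dim`, the bottleneck path re-expands to `hidden_dim` instead of `embed_dim`, the two `ActivationBalancer`s operate on the `hidden_dim`-wide activations, and `out_proj` maps `hidden_dim` back to `embed_dim`. The call site in `ZipformerEncoderLayer` passes `embed_dim // 2`, so the module's inner width is halved while its external interface is unchanged. Below is a minimal, runnable sketch of this shape bookkeeping; it substitutes plain `nn.Linear` for icefall's `LinearWithAuxLoss`/`ScaledLinear`, omits the balancers and whitening, and the `to_bottleneck_proj` name and the sigmoid gating in `forward()` are illustrative assumptions, not code from this file.

```python
import torch
import torch.nn as nn


class AttentionSqueezeShapes(nn.Module):
    """Shape-only stand-in for AttentionSqueeze after this diff.

    Plain nn.Linear replaces LinearWithAuxLoss/ScaledLinear; the
    balancers, whitening, and attention-weighted pooling are omitted.
    """

    def __init__(self, embed_dim: int, hidden_dim: int, bottleneck_dim: int = 16):
        super().__init__()
        # embed_dim -> hidden_dim (previously embed_dim -> embed_dim)
        self.in_proj = nn.Linear(embed_dim, hidden_dim, bias=False)
        # Hypothetical down-projection into the bottleneck (name assumed).
        self.to_bottleneck_proj = nn.Linear(hidden_dim, bottleneck_dim)
        # bottleneck_dim -> hidden_dim (previously bottleneck_dim -> embed_dim)
        self.from_bottleneck_proj = nn.Linear(bottleneck_dim, hidden_dim)
        # hidden_dim -> embed_dim (previously embed_dim -> embed_dim)
        self.out_proj = nn.Linear(hidden_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, embed_dim)
        hidden = self.in_proj(x)                    # (..., hidden_dim)
        squeezed = self.to_bottleneck_proj(hidden)  # (..., bottleneck_dim)
        # Assumed squeeze-style gating; the real module modulates the
        # hidden activations with attention-pooled bottleneck features.
        scale = torch.sigmoid(self.from_bottleneck_proj(squeezed))
        return self.out_proj(hidden * scale)        # (..., embed_dim)


embed_dim = 384
m = AttentionSqueezeShapes(embed_dim, embed_dim // 2)  # matches the new call site
x = torch.randn(100, 8, embed_dim)                     # (seq_len, batch, embed_dim)
assert m(x).shape == x.shape                           # module stays shape-preserving
```

With `hidden_dim = embed_dim // 2`, the `in_proj`/`out_proj` pair costs `2 * embed_dim * hidden_dim = embed_dim**2` weights, the same as one square `embed_dim x embed_dim` projection, so the change roughly halves the parameter count of those two layers.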