diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index eae046e67..77a67782a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -465,6 +465,8 @@ class RelPositionMultiheadAttention(nn.Module):
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
         self.in_balancer = ActivationBalancer(channel_dim=-1, max_abs=5.0)
+        self.proj_balancer = ActivationBalancer(channel_dim=-1, min_positive=0.0,
+                                                max_positive=1.0, max_abs=10.0)
         self.out_proj = ScaledLinear(
             embed_dim, embed_dim, bias=True, initial_scale=0.5
         )
@@ -774,7 +776,7 @@ class RelPositionMultiheadAttention(nn.Module):
         pos_emb_bsz = pos_emb.size(0)
         assert pos_emb_bsz in (1, bsz)  # actually it is 1
-        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
+        p = self.proj_balancer(self.linear_pos(pos_emb)).view(pos_emb_bsz, -1, num_heads, head_dim)
         p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

         q_with_bias_u = (q + self.pos_bias_u).transpose(
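
For context: `ActivationBalancer` (from icefall's `scaling.py`) is an identity in the forward pass; during training it modifies the backward gradients so that per-channel activation statistics are nudged toward the configured bounds. In this diff it caps the magnitude of the positional projection's output at `max_abs=10.0` (the `min_positive=0.0, max_positive=1.0` settings effectively disable the sign-proportion constraint). The sketch below is a minimal standalone approximation of just the `max_abs` behavior, not icefall's actual implementation; the names `SimpleBalancer`, `SimpleBalancerFunction`, and `grad_scale` are illustrative, and icefall's real version also balances sign statistics and applies its correction stochastically.

```python
import torch
from torch import Tensor, nn


class SimpleBalancerFunction(torch.autograd.Function):
    """Identity in the forward pass; in the backward pass, adds a small
    gradient term that pushes activations whose magnitude exceeds
    ``max_abs`` back toward zero."""

    @staticmethod
    def forward(ctx, x: Tensor, max_abs: float, grad_scale: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.max_abs = max_abs
        ctx.grad_scale = grad_scale
        return x

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        (x,) = ctx.saved_tensors
        # Where |x| > max_abs, add a penalty gradient in the sign of x,
        # scaled relative to the incoming gradient's typical magnitude,
        # so a gradient-descent step shrinks oversized activations.
        too_large = (x.abs() > ctx.max_abs).to(grad_output.dtype)
        penalty = ctx.grad_scale * grad_output.abs().mean() * too_large * x.sign()
        return grad_output + penalty, None, None


class SimpleBalancer(nn.Module):
    """Hypothetical stand-in for ActivationBalancer's max_abs constraint."""

    def __init__(self, max_abs: float = 10.0, grad_scale: float = 0.01):
        super().__init__()
        self.max_abs = max_abs
        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        if not self.training:
            return x  # no effect at inference time
        return SimpleBalancerFunction.apply(x, self.max_abs, self.grad_scale)


if __name__ == "__main__":
    torch.manual_seed(0)
    balancer = SimpleBalancer(max_abs=1.0, grad_scale=0.1)
    balancer.train()
    x = (5.0 * torch.randn(4, 8)).requires_grad_(True)
    y = balancer(x)            # forward is the identity
    assert torch.equal(y, x)
    y.sum().backward()
    # Oversized entries get gradients of 1.1 (positive x) or 0.9
    # (negative x) instead of 1.0, pulling them back toward the bound.
    print(x.grad[x.abs() > 1.0][:3])
```

Applied as in the diff, the wrapper is transparent at inference and only shapes gradients during training, which is why `self.proj_balancer(self.linear_pos(pos_emb))` can replace `self.linear_pos(pos_emb)` without changing the model's decoding-time behavior.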