diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 29685278b..ca7e4c200 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -1044,8 +1044,19 @@ class RelPositionMultiheadAttention(nn.Module):
         self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim,
                                        bias=False, initial_scale=0.05)
 
-        # the following are for diagnosics only, see --print-diagnostics option.
-        # they only copy their inputs.
+        # this is to stop a failure mode where the output gets very small and is
+        # dominated by the mean (the min_positive and max_positive will stop the mean
+        # being much larger than the variance). Make min_abs very small because even for normal,
+        # functional self_attn layers, the output rms can be very small.
+        self.out_balancer = ActivationBalancer(embed_dim,
+                                               channel_dim=-1,
+                                               min_positive=0.33,
+                                               max_positive=0.66,
+                                               min_abs=0.005, max_abs=1.0,
+                                               min_prob=0.05)
+
+
+        # the following are for diagnosics only, see --print-diagnostics option
         self.copy_pos_query = Identity()
         self.copy_query = Identity()
@@ -1061,6 +1072,13 @@ class RelPositionMultiheadAttention(nn.Module):
                                      whitening_limit=2.0,
                                      prob=(0.025, 0.25),
                                      grad_scale=0.025)
+        self.out_balancer2 = ActivationBalancer(embed_dim,
+                                                channel_dim=-1,
+                                                min_positive=0.33,
+                                                max_positive=0.66,
+                                                min_abs=0.005, max_abs=1.0,
+                                                min_prob=0.05)
+
 
     def forward(
@@ -1113,8 +1131,6 @@ class RelPositionMultiheadAttention(nn.Module):
             self.attention_dim,
             self.num_heads,
             self.dropout,
-            self.out_proj.weight,
-            self.out_proj.bias,
             training=self.training,
             key_padding_mask=key_padding_mask,
             attn_mask=attn_mask,
@@ -1129,8 +1145,6 @@ class RelPositionMultiheadAttention(nn.Module):
         attention_dim: int,
         num_heads: int,
         dropout_p: float,
-        out_proj_weight: Tensor,
-        out_proj_bias: Tensor,
         training: bool = True,
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
@@ -1142,7 +1156,6 @@ class RelPositionMultiheadAttention(nn.Module):
             attention_dim: dimension inside attention mechanism
             num_heads: parallel attention heads.
             dropout_p: probability of an element to be zeroed.
-            out_proj_weight, out_proj_bias: the output projection weight and bias.
             training: apply dropout if is ``True``.
             key_padding_mask: if provided, specified padding elements in the key will
                 be ignored by the attention. This is an binary mask. When the value is True,
@@ -1349,9 +1362,8 @@ class RelPositionMultiheadAttention(nn.Module):
             .contiguous()
             .view(seq_len, bsz, attention_dim // 2)
         )
-        attn_output = nn.functional.linear(
-            attn_output, out_proj_weight, out_proj_bias
-        )
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.out_balancer(attn_output)
         return attn_output, attn_output_weights
@@ -1391,7 +1403,9 @@ class RelPositionMultiheadAttention(nn.Module):
             .view(seq_len, bsz, self.attention_dim // 2)
         )
         # returned value is of shape (seq_len, bsz, embed_dim), like x.
-        return self.out_proj2(attn_output)
+        attn_output = self.out_proj2(attn_output)
+        attn_output = self.out_balancer2(attn_output)
+        return attn_output
 
     def _print_attn_stats(
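
For reference, the pattern this patch introduces is "output projection followed by
ActivationBalancer" on the attention output.  The sketch below is not part of the
patch: it shows that pattern in isolation, assumes ActivationBalancer and ScaledLinear
are importable from the recipe's scaling.py, and uses purely illustrative dimensions,
sequence length and batch size.

import torch
from scaling import ActivationBalancer, ScaledLinear

embed_dim, attention_dim = 384, 192                # illustrative sizes
out_proj = ScaledLinear(attention_dim // 2, embed_dim,
                        bias=True, initial_scale=0.05)
out_balancer = ActivationBalancer(embed_dim,
                                  channel_dim=-1,
                                  min_positive=0.33,
                                  max_positive=0.66,
                                  min_abs=0.005, max_abs=1.0,
                                  min_prob=0.05)

# (seq_len, bsz, attention_dim // 2) -> (seq_len, bsz, embed_dim)
attn_output = torch.randn(100, 8, attention_dim // 2, requires_grad=True)
attn_output = out_proj(attn_output)
attn_output = out_balancer(attn_output)
# The balancer returns its input unchanged in the forward pass; it only modifies
# gradients in the backward pass (min_prob controls how often it is active),
# nudging the per-channel mean and rms toward the configured bounds.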