diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 81d7708f9..c2e505944 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -449,6 +449,12 @@ class RelPositionMultiheadAttention(nn.Module): self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + # Before applying the softmax to the matrix from the dot-product, + # we multiply that matrix with this one + self.pre_softmax_param = nn.Parameter( + torch.Tensor(num_heads, num_heads) + ) + self._reset_parameters() def _reset_parameters(self) -> None: @@ -459,6 +465,9 @@ class RelPositionMultiheadAttention(nn.Module): nn.init.xavier_uniform_(self.pos_bias_u) nn.init.xavier_uniform_(self.pos_bias_v) + stdv = 1.0 / math.sqrt(self.num_heads) + nn.init.normal_(self.pre_softmax_param, mean=0, std=stdv) + def forward( self, query: Tensor, @@ -780,6 +789,19 @@ class RelPositionMultiheadAttention(nn.Module): bsz * num_heads, tgt_len, -1 ) + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, -1 + ).permute(0, 2, 3, 1) + # now attn_output_weights is of shape (bsz, tgt_len, src_len, num_heads) + + attn_output_weights = torch.matmul( + attn_output_weights, self.pre_softmax_param + ) + + attn_output_weights = attn_output_weights.permute(0, 3, 1, 2).reshape( + bsz * num_heads, tgt_len, -1 + ) + assert list(attn_output_weights.size()) == [ bsz * num_heads, tgt_len,