diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index fb8123838..c22d6575e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -859,7 +859,7 @@ class RelPositionalEncoding(torch.nn.Module):
 
 
 class RelPositionMultiheadAttention(nn.Module):
-    r"""Multi-Head Attention layer with relative position encoding
+    r"""Multi-Head Attention layer with simplified relative position encoding
 
     See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
 
@@ -895,24 +895,7 @@ class RelPositionMultiheadAttention(nn.Module):
         )
 
         # linear transformation for positional encoding.
-        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
-        # these two learnable bias are used in matrix c and matrix d
-        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
-        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
-        self._reset_parameters()
-
-    def _pos_bias_u(self):
-        return self.pos_bias_u * self.pos_bias_u_scale.exp()
-
-    def _pos_bias_v(self):
-        return self.pos_bias_v * self.pos_bias_v_scale.exp()
-
-    def _reset_parameters(self) -> None:
-        nn.init.normal_(self.pos_bias_u, std=0.01)
-        nn.init.normal_(self.pos_bias_v, std=0.01)
+        self.linear_pos = ScaledLinear(embed_dim, num_heads, bias=True)
 
     def forward(
         self,
@@ -1217,35 +1200,25 @@ class RelPositionMultiheadAttention(nn.Module):
                 key_padding_mask.size(1), src_len
             )
 
-        q = q.transpose(0, 1)  # (batch, time1, head, d_k)
+        q = q.permute(1, 2, 0, 3)  # (batch, head, time1, d_k)
 
         pos_emb_bsz = pos_emb.size(0)
        assert pos_emb_bsz in (1, bsz)  # actually it is 1
-        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
-        # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
-        p = p.permute(0, 2, 3, 1)
-
-        q_with_bias_u = (q + self._pos_bias_u()).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
-
-        q_with_bias_v = (q + self._pos_bias_v()).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
 
         # compute attention score
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p
-        )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(matrix_bd, left_context)
+        pos_emb = self.linear_pos(pos_emb)  # (1, 2*time1-1, head)
+        matrix_bd = (
+            pos_emb.transpose(1, 2).unsqueeze(2).repeat(1, 1, tgt_len, 1)
+        )  # (1, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(
+            matrix_bd, left_context
+        )  # (1, head, time1, time2)
 
         attn_output_weights = (
             matrix_ac + matrix_bd
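
For reference, here is a minimal standalone sketch of the positional term this diff introduces: the relative positional encodings are projected straight to one scalar score per head and added to the plain content scores, instead of interacting with the query via `pos_bias_u`/`pos_bias_v`. It assumes toy tensor shapes, a plain `nn.Linear` in place of `ScaledLinear`, and a hypothetical non-streaming `rel_shift` helper (the real one in `conformer.py` also takes `left_context`); attention scaling is omitted for brevity.

```python
import torch
import torch.nn as nn


def rel_shift(x: torch.Tensor) -> torch.Tensor:
    """Turn (batch, head, time1, 2*time1-1) relative scores into
    (batch, head, time1, time1): for query i and key k, pick the entry
    at relative offset k - i (index time1-1 corresponds to offset 0).
    Hypothetical stand-in for the left_context-aware rel_shift."""
    b, h, t, n = x.shape
    assert n == 2 * t - 1
    idx = (
        torch.arange(t).unsqueeze(0) - torch.arange(t).unsqueeze(1) + (t - 1)
    )  # idx[i, k] = (k - i) + (t - 1)
    idx = idx.view(1, 1, t, t).expand(b, h, t, t)
    return x.gather(3, idx)


batch, num_heads, t, head_dim = 2, 4, 6, 16
embed_dim = num_heads * head_dim

# Content term (matrices a + c): plain q @ k^T, no pos_bias_u/v added to q.
q = torch.randn(batch, num_heads, t, head_dim)
k = torch.randn(batch, num_heads, t, head_dim)
matrix_ac = torch.matmul(q, k.transpose(-2, -1))  # (batch, head, time1, time2)

# Positional term (matrices b + d): project the relative positional encodings
# directly to one score per head; it no longer depends on the query content.
pos_emb = torch.randn(1, 2 * t - 1, embed_dim)  # as produced by RelPositionalEncoding
linear_pos = nn.Linear(embed_dim, num_heads, bias=True)
matrix_bd = linear_pos(pos_emb)  # (1, 2*time1-1, head)
matrix_bd = (
    matrix_bd.transpose(1, 2).unsqueeze(2).repeat(1, 1, t, 1)
)  # (1, head, time1, 2*time1-1)
matrix_bd = rel_shift(matrix_bd)  # (1, head, time1, time2)

attn_weights = (matrix_ac + matrix_bd).softmax(dim=-1)
print(attn_weights.shape)  # torch.Size([2, 4, 6, 6])
```

Because `matrix_bd` is computed from the positional encodings alone, it is shared across the batch and across query content, which removes the `pos_bias_u`/`pos_bias_v` parameters and the per-head `(embed_dim, embed_dim)` projection at the cost of a less expressive positional term.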