From d70f6e21f22785a6fda2141cd5be5a4b7530a7c2 Mon Sep 17 00:00:00 2001
From: JinZr <60612200+JinZr@users.noreply.github.com>
Date: Wed, 26 Jul 2023 19:01:53 +0800
Subject: [PATCH] updated code for decoding

---
 .../beam_search.py                  |  6 ++--
 .../alignment_attention_module.py   | 30 +++++++++++--------
 2 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 6b9c6e7a1..16279d05d 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -825,13 +825,15 @@ def deprecated_greedy_search_batch_for_cross_attn(
         # current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
         # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
         current_encoder_out = model.joiner.label_level_am_attention(
-            encoder_out[:, : t + 1, :].unsqueeze(2),
+            encoder_out.unsqueeze(2),
             decoder_out.unsqueeze(2),
-            encoder_out_lens,
+            # encoder_out_lens,
+            None,
         )
 
         logits = model.joiner(
             current_encoder_out,
             decoder_out.unsqueeze(1),
+            None,
             apply_attn=False,
             project_input=False,
         )
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
index 631fd462f..d4e6208b3 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
@@ -73,7 +73,7 @@ class CrossAttention(nn.Module):
             batch_size,
             lm_seq_len,
             am_seq_len,
-        ), f"{attn_weights.shape}"
+        ), f"{attn_weights.shape} {x.shape}"
 
         x = self.in_proj(x)  # (am_seq_len, batch_size, num_heads * value_head_dim)
         # print("projected x.shape", x.shape)
@@ -406,19 +406,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         if key_padding_mask is not None:
             # (batch, max_len)
-            key_padding_mask = (
-                key_padding_mask.unsqueeze(1)
-                .expand(
-                    key_padding_mask.shape[0],  # b
-                    self.prune_range,
-                    key_padding_mask.shape[1],  # l
+            if b_p_dim == key_padding_mask.shape[0] * self.prune_range:
+                key_padding_mask = (
+                    key_padding_mask.unsqueeze(1)
+                    .expand(
+                        key_padding_mask.shape[0],  # b
+                        self.prune_range,
+                        key_padding_mask.shape[1],  # l
+                    )
+                    .reshape(b_p_dim, am_seq_len)
+                    .unsqueeze(1)
+                    .unsqueeze(0)
                 )
-                .reshape(b_p_dim, am_seq_len)
-                .unsqueeze(1)
-                .unsqueeze(0)
-            )
+                # (1, b * p, 1, T)
+            else:
+                key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(0)
+                # (1, b, 1, T)
             # print(key_padding_mask.shape)
-            # (1, b * p, 1, T)
 
             attn_scores = attn_scores.masked_fill(
                 key_padding_mask,
@@ -492,7 +496,7 @@ class AlignmentAttentionModule(nn.Module):
     def forward(
         self, am_pruned: Tensor, lm_pruned: Tensor, lengths: torch.Tensor
     ) -> Tensor:
-        src_key_padding_mask = make_pad_mask(lengths)
+        src_key_padding_mask = make_pad_mask(lengths) if lengths is not None else None
         # (batch, max_len)
 
         if am_pruned.ndim == 4 and lm_pruned.ndim == 4:
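
A minimal standalone sketch (not part of the patch) of the key_padding_mask reshaping introduced in the alignment_attention_module.py hunk above: the mask is repeated prune_range times per utterance when the attention scores' batch dimension is b * prune_range, and otherwise kept at one row per utterance (the new decoding path). The helper name expand_key_padding_mask and the toy tensors are assumptions for illustration only; they do not exist in icefall.

# Illustrative sketch of the mask expansion logic in the patch above.
import torch

def expand_key_padding_mask(
    key_padding_mask: torch.Tensor,  # (batch, am_seq_len), True at padded frames
    b_p_dim: int,                    # batch dim of the attention scores: b or b * prune_range
    prune_range: int,
) -> torch.Tensor:
    batch, am_seq_len = key_padding_mask.shape
    if b_p_dim == batch * prune_range:
        # Pruned case: repeat each utterance's mask prune_range times so it
        # broadcasts against scores of shape (num_heads, b * prune_range, lm_seq_len, am_seq_len).
        return (
            key_padding_mask.unsqueeze(1)
            .expand(batch, prune_range, am_seq_len)
            .reshape(b_p_dim, am_seq_len)
            .unsqueeze(1)
            .unsqueeze(0)
        )  # (1, b * prune_range, 1, am_seq_len)
    # Unpruned case (e.g. greedy decoding with lengths=None upstream):
    # keep one mask row per utterance.
    return key_padding_mask.unsqueeze(1).unsqueeze(0)  # (1, b, 1, am_seq_len)

# Toy usage: batch of 2 utterances, 5 frames, prune range 4.
mask = torch.tensor([[False, False, False, True, True],
                     [False, False, True, True, True]])
print(expand_key_padding_mask(mask, b_p_dim=8, prune_range=4).shape)  # (1, 8, 1, 5)
print(expand_key_padding_mask(mask, b_p_dim=2, prune_range=4).shape)  # (1, 2, 1, 5)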