diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
index 51a09db5a..41d8a9744 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
@@ -120,7 +120,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
             am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
             pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
-            key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
+            key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
                that are True in this mask will be ignored as sources in the attention weighting.
             attn_mask: mask of shape (seq_len, seq_len) or (seq_len, batch_size * prune_range, batch_size * prune_range),
 
@@ -333,13 +333,15 @@ class AlignmentAttentionModule(nn.Module):
 
         pos_emb = self.pos_encode(merged_am_pruned)
 
-        attn_weights = self.cross_attn_weights(merged_lm_pruned, merged_am_pruned, pos_emb)
+        attn_weights = self.cross_attn_weights(
+            merged_lm_pruned, merged_am_pruned, pos_emb
+        )
         label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights)
         # (T, batch_size * prune_range, encoder_dim)
 
-        return label_level_am_representation \
-            .reshape(T, batch_size, prune_range, encoder_dim) \
-            .permute(1, 0, 2, 3)
+        return label_level_am_representation.reshape(
+            T, batch_size, prune_range, encoder_dim
+        ).permute(1, 0, 2, 3)
 
 
 if __name__ == "__main__":
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module_debug.py b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module_debug.py
index 12f66447d..5a9a425cb 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module_debug.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module_debug.py
@@ -120,7 +120,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
             am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
             pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
-            key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
+            key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
                that are True in this mask will be ignored as sources in the attention weighting.
             attn_mask: mask of shape (seq_len, seq_len) or (seq_len, batch_size * prune_range, batch_size * prune_range),
 
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py b/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
index f03cc930e..1ab7a3662 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/joiner.py
@@ -16,6 +16,7 @@
 
 import torch
 import torch.nn as nn
+from alignment_attention_module import AlignmentAttentionModule
 from scaling import ScaledLinear
 
 
@@ -29,6 +30,7 @@ class Joiner(nn.Module):
     ):
         super().__init__()
 
+        self.label_level_am_attention = AlignmentAttentionModule()
         self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
         self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
         self.output_linear = nn.Linear(joiner_dim, vocab_size)
@@ -52,12 +54,15 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
+        assert encoder_out.ndim == decoder_out.ndim, (
+            encoder_out.shape,
+            decoder_out.shape,
+        )
+
+        encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
 
         if project_input:
-            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
-                decoder_out
-            )
+            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
             logit = encoder_out + decoder_out
 
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/model.py b/egs/librispeech/ASR/zipformer_label_level_algn/model.py
index 823a45ae1..a7ce4e495 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/model.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/model.py
@@ -21,7 +21,6 @@ from typing import Optional, Tuple
 
 import k2
 import torch
 import torch.nn as nn
-from alignment_attention_module import AlignmentAttentionModule
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
@@ -35,7 +34,6 @@ class AsrModel(nn.Module):
         encoder: EncoderInterface,
         decoder: Optional[nn.Module] = None,
         joiner: Optional[nn.Module] = None,
-        label_level_am_attention: Optional[nn.Module] = None,
         encoder_dim: int = 384,
         decoder_dim: int = 512,
         vocab_size: int = 500,
@@ -113,8 +111,6 @@ class AsrModel(nn.Module):
                 nn.LogSoftmax(dim=-1),
             )
 
-        self.label_level_am_attention = label_level_am_attention
-
     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -264,14 +260,11 @@ class AsrModel(nn.Module):
             ranges=ranges,
         )
 
-
-        label_level_am_pruned = self.label_level_am_attention(am_pruned, lm_pruned)
-
         # logits : [B, T, prune_range, vocab_size]
 
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False)
+        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/train.py b/egs/librispeech/ASR/zipformer_label_level_algn/train.py
index c43404bc8..cb0ea8ef7 100755
--- a/egs/librispeech/ASR/zipformer_label_level_algn/train.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/train.py
@@ -603,16 +603,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     )
     return joiner
 
-def get_attn_module(params: AttributeDict) -> nn.Module:
-    attn_module = AlignmentAttentionModule()
-    return attn_module
 
 def get_model(params: AttributeDict) -> nn.Module:
-    assert (
-        params.use_transducer or params.use_ctc
-    ), (f"At least one of them should be True, "
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
         f"but got params.use_transducer={params.use_transducer}, "
-        f"params.use_ctc={params.use_ctc}")
+        f"params.use_ctc={params.use_ctc}"
+    )
 
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
@@ -624,14 +621,11 @@ def get_model(params: AttributeDict) -> nn.Module:
         decoder = None
         joiner = None
 
-    attn = get_attn_module(params)
-
     model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        label_level_am_attention=attn,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         vocab_size=params.vocab_size,
@@ -815,17 +809,16 @@ def compute_loss(
             # take down the scale on the simple loss from 1.0 at the start
             # to params.simple_loss scale by warm_step.
             simple_loss_scale = (
-                s if batch_idx_train >= warm_step
+                s
+                if batch_idx_train >= warm_step
                 else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
             )
             pruned_loss_scale = (
-                1.0 if batch_idx_train >= warm_step
+                1.0
+                if batch_idx_train >= warm_step
                 else 0.1 + 0.9 * (batch_idx_train / warm_step)
            )
-            loss += (
-                simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
-            )
+            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
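For readers tracing shapes through the reworked AlignmentAttentionModule.forward, here is a minimal, self-contained sketch of the merge/unmerge that brackets the cross-attention call. Only the unmerge (the reformatted return statement) appears in the diff above; the merge step and all dimension sizes are assumptions for illustration, chosen to match the docstring's expected layout of (seq_len, batch_size * prune_range, embed_dim).

import torch

# Illustrative sizes only (not taken from the recipe's configuration).
batch_size, T, prune_range, encoder_dim = 2, 25, 5, 512

# Pruned encoder output as handed to the joiner: (B, T, prune_range, encoder_dim).
am_pruned = torch.randn(batch_size, T, prune_range, encoder_dim)

# Assumed merge step: fold batch and prune_range together so the attention sees
# (seq_len, batch_size * prune_range, encoder_dim), as described in the docstring.
merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
    T, batch_size * prune_range, encoder_dim
)

# ... cross_attn_weights / cross_attn would run on merged_am_pruned here ...
label_level_am_representation = merged_am_pruned  # stand-in for the attention output

# Unmerge, matching the return statement in the diff:
out = label_level_am_representation.reshape(
    T, batch_size, prune_range, encoder_dim
).permute(1, 0, 2, 3)

assert out.shape == (batch_size, T, prune_range, encoder_dim)

The merge lets the attention treat every (utterance, pruned-label) pair as an independent batch element over the time axis, and the final reshape/permute restores the (N, T, s_range, C) layout that the joiner combines with the decoder projection.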