diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/model.py b/egs/librispeech/ASR/zipformer_label_level_algn/model.py index b541ee697..c943ea0d2 100644 --- a/egs/librispeech/ASR/zipformer_label_level_algn/model.py +++ b/egs/librispeech/ASR/zipformer_label_level_algn/model.py @@ -21,10 +21,11 @@ from typing import Optional, Tuple import k2 import torch import torch.nn as nn +from alignment_attention_module import AlignmentAttentionModule from encoder_interface import EncoderInterface +from scaling import ScaledLinear from icefall.utils import add_sos, make_pad_mask -from scaling import ScaledLinear class AsrModel(nn.Module): @@ -260,11 +261,14 @@ class AsrModel(nn.Module): ranges=ranges, ) + label_level_am_attention = AlignmentAttentionModule() + label_level_am_pruned = label_level_am_attention(am_pruned, lm_pruned) + # logits : [B, T, prune_range, vocab_size] # project_input=False since we applied the decoder's input projections # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) + logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False) with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned(