diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/model.py b/egs/librispeech/ASR/zipformer_label_level_algn/model.py index c943ea0d2..cd1d63674 100644 --- a/egs/librispeech/ASR/zipformer_label_level_algn/model.py +++ b/egs/librispeech/ASR/zipformer_label_level_algn/model.py @@ -112,6 +112,8 @@ class AsrModel(nn.Module): nn.LogSoftmax(dim=-1), ) + self.label_level_am_attention = AlignmentAttentionModule() + def forward_encoder( self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -261,8 +263,8 @@ class AsrModel(nn.Module): ranges=ranges, ) - label_level_am_attention = AlignmentAttentionModule() - label_level_am_pruned = label_level_am_attention(am_pruned, lm_pruned) + + label_level_am_pruned = self.label_level_am_attention(am_pruned, lm_pruned) # logits : [B, T, prune_range, vocab_size]