From ab9affd3e5ecc9796b017a61fbb278b2cb149b0d Mon Sep 17 00:00:00 2001 From: zr_jin <60612200+JinZr@users.noreply.github.com> Date: Sun, 23 Jul 2023 18:10:39 +0800 Subject: [PATCH] updated `model.py` to incorporate AlignmentAttnModule output of the attn module will replace the am_pruned as the input of the joiner model --- egs/librispeech/ASR/zipformer_label_level_algn/model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/model.py b/egs/librispeech/ASR/zipformer_label_level_algn/model.py index b541ee697..c943ea0d2 100644 --- a/egs/librispeech/ASR/zipformer_label_level_algn/model.py +++ b/egs/librispeech/ASR/zipformer_label_level_algn/model.py @@ -21,10 +21,11 @@ from typing import Optional, Tuple import k2 import torch import torch.nn as nn +from alignment_attention_module import AlignmentAttentionModule from encoder_interface import EncoderInterface +from scaling import ScaledLinear from icefall.utils import add_sos, make_pad_mask -from scaling import ScaledLinear class AsrModel(nn.Module): @@ -260,11 +261,14 @@ class AsrModel(nn.Module): ranges=ranges, ) + label_level_am_attention = AlignmentAttentionModule() + label_level_am_pruned = label_level_am_attention(am_pruned, lm_pruned) + # logits : [B, T, prune_range, vocab_size] # project_input=False since we applied the decoder's input projections # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) + logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False) with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned(