Updated model.py to incorporate AlignmentAttentionModule

The output of the attention module replaces am_pruned as the input to the joiner model (a hedged sketch of what such a module could look like follows the diff below).
zr_jin 2023-07-23 18:10:39 +08:00
parent 706cbae5d0
commit ab9affd3e5


@@ -21,10 +21,11 @@ from typing import Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
+from alignment_attention_module import AlignmentAttentionModule
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
 from icefall.utils import add_sos, make_pad_mask
 
 class AsrModel(nn.Module):
@@ -260,11 +261,14 @@ class AsrModel(nn.Module):
             ranges=ranges,
         )
 
+        label_level_am_attention = AlignmentAttentionModule()
+        label_level_am_pruned = label_level_am_attention(am_pruned, lm_pruned)
+
         # logits : [B, T, prune_range, vocab_size]
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+        logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False)
 
         with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
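
The alignment_attention_module.py file itself is not part of this diff, so only the call signature is known from the code above: the module takes (am_pruned, lm_pruned), both of shape (B, T, prune_range, joiner_dim) after do_rnnt_pruning with the projected encoder/decoder outputs, and must return a tensor shaped like am_pruned so it can stand in for it in the joiner call. The following is a minimal sketch of one way such a module could be built, assuming a cross-attention design in which lm_pruned queries am_pruned over the prune_range axis; the embed_dim and num_heads values are illustrative, not taken from the commit.

import torch
import torch.nn as nn


class AlignmentAttentionModule(nn.Module):
    """Hypothetical sketch (not the actual module from this commit):
    label-level cross-attention over the pruned acoustic features."""

    def __init__(self, embed_dim: int = 512, num_heads: int = 4):
        super().__init__()
        # batch_first=True so inputs are (batch, seq, embed_dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim, num_heads, batch_first=True
        )

    def forward(self, am_pruned: torch.Tensor, lm_pruned: torch.Tensor) -> torch.Tensor:
        # am_pruned, lm_pruned: (B, T, prune_range, C)
        B, T, s_range, C = am_pruned.shape
        # Fold (B, T) into the batch dimension so attention runs independently
        # per frame; each label position queries the pruned acoustic features.
        q = lm_pruned.reshape(B * T, s_range, C)
        kv = am_pruned.reshape(B * T, s_range, C)
        out, _ = self.cross_attn(q, kv, kv, need_weights=False)
        return out.reshape(B, T, s_range, C)


# Shape check mirroring how the diff uses the module:
attn = AlignmentAttentionModule(embed_dim=512, num_heads=4)
am = torch.randn(2, 50, 5, 512)   # (B, T, prune_range, joiner_dim)
lm = torch.randn(2, 50, 5, 512)
label_level_am = attn(am, lm)     # same shape as am, fed to the joiner

One design note: the diff instantiates the module inside the forward computation on every call; for the module's parameters to be trained, it would normally be registered in __init__ instead, but that is outside the scope of this sketch.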