Updated model.py to incorporate AlignmentAttentionModule

The output of the attention module replaces am_pruned as the input to the joiner.
Author: zr_jin
Date: 2023-07-23 18:10:39 +08:00
Parent: 706cbae5d0
Commit: ab9affd3e5


@@ -21,10 +21,11 @@ from typing import Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
+from alignment_attention_module import AlignmentAttentionModule
 from encoder_interface import EncoderInterface
-from scaling import ScaledLinear
 from icefall.utils import add_sos, make_pad_mask
+from scaling import ScaledLinear
@@ -260,11 +261,14 @@ class AsrModel(nn.Module):
             ranges=ranges,
         )
+        label_level_am_attention = AlignmentAttentionModule()
+        label_level_am_pruned = label_level_am_attention(am_pruned, lm_pruned)
+
         # logits : [B, T, prune_range, vocab_size]
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+        logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False)
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
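
As context for the change above: the diff constructs AlignmentAttentionModule inside the loss computation, so a fresh module with newly initialized weights is created on every call. Below is a minimal sketch of the more conventional wiring, registering the module once in __init__ so its parameters are trained. AsrModelSketch and compute_pruned_logits are hypothetical, pared-down stand-ins for AsrModel and its forward path, shown only to illustrate where the module would live; the no-argument constructor matches the call in the diff.

import torch
import torch.nn as nn

from alignment_attention_module import AlignmentAttentionModule


class AsrModelSketch(nn.Module):
    # Hypothetical, simplified stand-in for AsrModel; only the pieces
    # relevant to this commit are shown.
    def __init__(self, joiner: nn.Module):
        super().__init__()
        self.joiner = joiner
        # Register the attention module once so its parameters appear in
        # model.parameters() and are updated by the optimizer.
        self.label_level_am_attention = AlignmentAttentionModule()

    def compute_pruned_logits(
        self, am_pruned: torch.Tensor, lm_pruned: torch.Tensor
    ) -> torch.Tensor:
        # am_pruned / lm_pruned: (B, T, prune_range, dim) tensors produced
        # by k2.do_rnnt_pruning earlier in the forward pass (not shown).
        label_level_am_pruned = self.label_level_am_attention(am_pruned, lm_pruned)
        # Per the commit message, the attention output replaces am_pruned
        # as the acoustic input of the joiner.
        return self.joiner(label_level_am_pruned, lm_pruned, project_input=False)

Constructed per call as in the diff, the module's weights would be re-randomized on every batch and never seen by the optimizer; hoisting the construction into __init__ avoids both problems.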