mirror of https://github.com/k2-fsa/icefall.git
updated model.py to incorporate AlignmentAttentionModule

The output of the attention module replaces am_pruned as the input to the joiner model.
parent 706cbae5d0
commit ab9affd3e5
@@ -21,10 +21,11 @@ from typing import Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
+from alignment_attention_module import AlignmentAttentionModule
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

 from icefall.utils import add_sos, make_pad_mask


 class AsrModel(nn.Module):
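The implementation of AlignmentAttentionModule itself is not part of this diff; it is imported from alignment_attention_module.py. Below is a minimal sketch of an interface consistent with the call site in the next hunk, assuming both pruned tensors carry the (batch, T, prune_range, joiner_dim) shape produced by k2.do_rnnt_pruning. The class body, dimensions, and head count are illustrative assumptions, not the committed code:

# Hypothetical sketch only -- the real module lives in
# alignment_attention_module.py and is not shown in this commit.
import torch
import torch.nn as nn


class AlignmentAttentionModule(nn.Module):
    """Lets each pruned label position (lm_pruned, used as queries)
    attend over the pruned acoustic frames (am_pruned, keys/values),
    returning a label-level acoustic representation with the same
    shape as am_pruned so it can replace it at the joiner input."""

    def __init__(self, joiner_dim: int = 512, num_heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=joiner_dim, num_heads=num_heads, batch_first=True
        )

    def forward(self, am_pruned: torch.Tensor, lm_pruned: torch.Tensor) -> torch.Tensor:
        # Both inputs: (B, T, prune_range, joiner_dim), as produced by
        # k2.do_rnnt_pruning. Fold (B, T) together so each pruned window
        # is one attention sequence of length prune_range.
        B, T, s_range, C = am_pruned.shape
        am = am_pruned.reshape(B * T, s_range, C)
        lm = lm_pruned.reshape(B * T, s_range, C)
        out, _ = self.attn(query=lm, key=am, value=am)
        return out.reshape(B, T, s_range, C)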
@@ -260,11 +261,14 @@ class AsrModel(nn.Module):
             ranges=ranges,
         )

+        label_level_am_attention = AlignmentAttentionModule()
+        label_level_am_pruned = label_level_am_attention(am_pruned, lm_pruned)
+
         # logits : [B, T, prune_range, vocab_size]

         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+        logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False)

         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
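One thing worth flagging about the hunk above: AlignmentAttentionModule() is constructed inside the forward computation, so a freshly initialized instance is created on every call, and its parameters are never registered on AsrModel, moved to the model's device, or updated by the optimizer. The usual pattern is to register the module once in __init__ and only call it in forward. The sketch below condenses this into a hypothetical TinyAsrModel, reusing the AlignmentAttentionModule sketch above; the names, vocab size, and dimensions are illustrative, not icefall's actual code:

# Minimal self-contained sketch (hypothetical TinyAsrModel, not icefall's
# AsrModel) showing the attention module registered in __init__ rather than
# built inside forward, so its weights train and move with the model.
import torch
import torch.nn as nn


class TinyAsrModel(nn.Module):
    def __init__(self, joiner_dim: int = 512, vocab_size: int = 500):
        super().__init__()
        # Registered once here; a per-forward AlignmentAttentionModule()
        # would re-initialize its weights on every call.
        self.label_level_am_attention = AlignmentAttentionModule(joiner_dim)
        self.joiner_out = nn.Linear(joiner_dim, vocab_size)  # stand-in joiner

    def forward(self, am_pruned: torch.Tensor, lm_pruned: torch.Tensor) -> torch.Tensor:
        label_level_am_pruned = self.label_level_am_attention(am_pruned, lm_pruned)
        return self.joiner_out(label_level_am_pruned + lm_pruned)


# Usage with k2.do_rnnt_pruning-style shapes: (B, T, prune_range, joiner_dim)
model = TinyAsrModel()
am = torch.randn(2, 10, 5, 512)
lm = torch.randn(2, 10, 5, 512)
print(model(am, lm).shape)  # torch.Size([2, 10, 5, 500])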