mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
updated model.py to incorporate AlignmentAttnModule
The output of the attention module will replace am_pruned as the input to the joiner model.
This commit is contained in:
parent 706cbae5d0
commit ab9affd3e5
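
As context for the diff below: after k2.do_rnnt_pruning, am_pruned and lm_pruned both have the layout [B, T, prune_range, dim], and this commit makes the joiner consume the attention module's output in place of am_pruned. The sketch below only illustrates the shape constraint that this swap relies on; the sizes and the tanh-plus-linear joiner stand-in are assumptions for illustration, not code from this commit.

# Shape sketch for the joiner-input swap; every size below is assumed.
import torch
import torch.nn as nn

B, T, prune_range, joiner_dim, vocab_size = 2, 50, 5, 512, 500

# After k2.do_rnnt_pruning (and the joiner's input projections), both pruned
# tensors share the layout [batch, frames, prune_range, joiner_dim].
am_pruned = torch.randn(B, T, prune_range, joiner_dim)
lm_pruned = torch.randn(B, T, prune_range, joiner_dim)

# Hypothetical joiner stand-in: combine the two streams and project to the vocabulary.
output_linear = nn.Linear(joiner_dim, vocab_size)

# Whatever AlignmentAttentionModule computes, its output must keep am_pruned's
# layout so it can replace it at the joiner without further changes.
label_level_am_pruned = am_pruned  # placeholder for the attention module's output
logits = output_linear(torch.tanh(label_level_am_pruned + lm_pruned))
print(logits.shape)  # torch.Size([2, 50, 5, 500]) -> [B, T, prune_range, vocab_size]
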
@@ -21,10 +21,11 @@ from typing import Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
+from alignment_attention_module import AlignmentAttentionModule
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear
 
 from icefall.utils import add_sos, make_pad_mask
-from scaling import ScaledLinear
 
 
 class AsrModel(nn.Module):
@@ -260,11 +261,14 @@ class AsrModel(nn.Module):
             ranges=ranges,
         )
 
+        label_level_am_attention = AlignmentAttentionModule()
+        label_level_am_pruned = label_level_am_attention(am_pruned, lm_pruned)
+
         # logits : [B, T, prune_range, vocab_size]
 
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+        logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False)
 
         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
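
The definition of AlignmentAttentionModule is not part of this diff, so only its call-site interface is known: it is constructed with no arguments and maps (am_pruned, lm_pruned) to a tensor that replaces am_pruned. A minimal cross-attention sketch under that assumption follows; the class name, embed_dim, num_heads, and the choice to attend over the prune_range axis are all guesses, not the contents of alignment_attention_module.py.

import torch
import torch.nn as nn


class AlignmentAttentionModuleSketch(nn.Module):
    """Hypothetical stand-in that matches the call site in the diff above."""

    def __init__(self, embed_dim: int = 512, num_heads: int = 4) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, am_pruned: torch.Tensor, lm_pruned: torch.Tensor) -> torch.Tensor:
        # am_pruned, lm_pruned: [B, T, prune_range, embed_dim]
        B, T, prune_range, dim = am_pruned.shape
        # Treat each (batch, frame) pair as one attention instance over the pruned labels.
        am = am_pruned.reshape(B * T, prune_range, dim)
        lm = lm_pruned.reshape(B * T, prune_range, dim)
        # Label/decoder states query the pruned acoustic frames, giving a
        # label-level acoustic representation with am_pruned's layout.
        out, _ = self.attn(query=lm, key=am, value=am)
        return out.reshape(B, T, prune_range, dim)


# Usage: the output keeps am_pruned's layout, so it can be fed to the joiner in its place.
module = AlignmentAttentionModuleSketch()
am = torch.randn(2, 50, 5, 512)
lm = torch.randn(2, 50, 5, 512)
print(module(am, lm).shape)  # torch.Size([2, 50, 5, 512])
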