mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
load attn param
This commit is contained in:
parent
8bc0956503
commit
056efeef30
@ -35,6 +35,7 @@ class AsrModel(nn.Module):
|
||||
encoder: EncoderInterface,
|
||||
decoder: Optional[nn.Module] = None,
|
||||
joiner: Optional[nn.Module] = None,
|
||||
label_level_am_attention: Optional[nn.Module] = None,
|
||||
encoder_dim: int = 384,
|
||||
decoder_dim: int = 512,
|
||||
vocab_size: int = 500,
|
||||
@ -112,7 +113,7 @@ class AsrModel(nn.Module):
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
self.label_level_am_attention = AlignmentAttentionModule()
|
||||
self.label_level_am_attention = label_level_am_attention
|
||||
|
||||
def forward_encoder(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
|
@ -65,6 +65,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from alignment_attention_module import AlignmentAttentionModule
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
@ -602,6 +603,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
)
|
||||
return joiner
|
||||
|
||||
def get_attn_module(params: AttributeDict) -> nn.Module:
|
||||
attn_module = AlignmentAttentionModule()
|
||||
return attn_module
|
||||
|
||||
def get_model(params: AttributeDict) -> nn.Module:
|
||||
assert (
|
||||
@ -620,11 +624,14 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = None
|
||||
joiner = None
|
||||
|
||||
attn = get_attn_module(params)
|
||||
|
||||
model = AsrModel(
|
||||
encoder_embed=encoder_embed,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
label_level_am_attention=attn,
|
||||
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
|
||||
decoder_dim=params.decoder_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
|
Loading…
x
Reference in New Issue
Block a user