load attn param

parent 8bc0956503
commit 056efeef30
@@ -35,6 +35,7 @@ class AsrModel(nn.Module):
         encoder: EncoderInterface,
         decoder: Optional[nn.Module] = None,
         joiner: Optional[nn.Module] = None,
+        label_level_am_attention: Optional[nn.Module] = None,
         encoder_dim: int = 384,
         decoder_dim: int = 512,
         vocab_size: int = 500,
@@ -112,7 +113,7 @@ class AsrModel(nn.Module):
             nn.LogSoftmax(dim=-1),
         )

-        self.label_level_am_attention = AlignmentAttentionModule()
+        self.label_level_am_attention = label_level_am_attention

     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
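The two hunks above change the model definition: instead of hard-coding AlignmentAttentionModule() in the constructor, AsrModel now receives the module as an optional argument. A minimal runnable sketch of that injection pattern, trimmed to the one new argument (the real constructor also takes encoder_embed, encoder, decoder, joiner, and the dimension/vocab settings shown in the first hunk):

from typing import Optional

import torch.nn as nn


class AsrModel(nn.Module):
    # Sketch: only the argument added by this commit is shown.
    def __init__(
        self,
        label_level_am_attention: Optional[nn.Module] = None,
    ):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule, so
        # its parameters appear in state_dict() and are visible to the
        # optimizer and to checkpoint save/load.
        self.label_level_am_attention = label_level_am_attention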
@@ -65,6 +65,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+from alignment_attention_module import AlignmentAttentionModule
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
@@ -602,6 +603,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     )
     return joiner

+def get_attn_module(params: AttributeDict) -> nn.Module:
+    attn_module = AlignmentAttentionModule()
+    return attn_module

 def get_model(params: AttributeDict) -> nn.Module:
     assert (
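The new get_attn_module helper mirrors the other get_*_model factories in this file: it accepts params for a uniform interface but does not read anything from it, since AlignmentAttentionModule is built with no arguments. A self-contained sketch with a stub in place of the real module (which lives in alignment_attention_module.py and is not part of this diff):

import torch.nn as nn


class AlignmentAttentionModule(nn.Module):
    # Stub for illustration only; the real module's internals are not in
    # this diff. A single linear layer gives it some parameters.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(384, 384)


def get_attn_module(params) -> nn.Module:
    # params is accepted to match the other factories but is unused here.
    attn_module = AlignmentAttentionModule()
    return attn_module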
@@ -620,11 +624,14 @@ def get_model(params: AttributeDict) -> nn.Module:
         decoder = None
         joiner = None

+    attn = get_attn_module(params)
+
     model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        label_level_am_attention=attn,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         vocab_size=params.vocab_size,
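Taken together, the wiring means the attention module's weights land in the model's state_dict under the label_level_am_attention. prefix, so the existing checkpoint save/load path covers them, which is presumably what the commit message "load attn param" refers to. A self-contained check, under the same stub assumptions as the sketches above:

import torch.nn as nn


class StubAttention(nn.Module):
    # Stand-in for AlignmentAttentionModule (defined outside this diff).
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(384, 384)


class StubAsrModel(nn.Module):
    # Trimmed to the one argument this commit adds.
    def __init__(self, label_level_am_attention=None):
        super().__init__()
        self.label_level_am_attention = label_level_am_attention


model = StubAsrModel(label_level_am_attention=StubAttention())
# Keys such as "label_level_am_attention.proj.weight" confirm the injected
# module's parameters travel with the model's checkpoints.
print([k for k in model.state_dict() if k.startswith("label_level_am_attention.")])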