load attn param

Author: zr_jin
Date: 2023-07-23 22:18:25 +08:00
parent 8bc0956503
commit 056efeef30
2 changed files with 9 additions and 1 deletion


@@ -35,6 +35,7 @@ class AsrModel(nn.Module):
         encoder: EncoderInterface,
         decoder: Optional[nn.Module] = None,
         joiner: Optional[nn.Module] = None,
+        label_level_am_attention: Optional[nn.Module] = None,
         encoder_dim: int = 384,
         decoder_dim: int = 512,
         vocab_size: int = 500,
@@ -112,7 +113,7 @@ class AsrModel(nn.Module):
             nn.LogSoftmax(dim=-1),
         )
-        self.label_level_am_attention = AlignmentAttentionModule()
+        self.label_level_am_attention = label_level_am_attention

     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
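The hunks above turn the label-level AM attention module into a constructor argument of AsrModel instead of a hard-coded AlignmentAttentionModule() instantiation. A minimal sketch of the resulting constructor wiring, assuming nothing about the other sub-modules beyond what the diff shows (the docstring and the simplified encoder type are illustrative), might look like this:

from typing import Optional

import torch.nn as nn


class AsrModel(nn.Module):
    """Sketch: the attention module is injected rather than constructed here."""

    def __init__(
        self,
        encoder: nn.Module,
        decoder: Optional[nn.Module] = None,
        joiner: Optional[nn.Module] = None,
        label_level_am_attention: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.joiner = joiner
        # Store whatever module the caller provides; may be None.
        self.label_level_am_attention = label_level_am_attention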

@@ -65,6 +65,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+from alignment_attention_module import AlignmentAttentionModule
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
@@ -602,6 +603,9 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     )
     return joiner

+def get_attn_module(params: AttributeDict) -> nn.Module:
+    attn_module = AlignmentAttentionModule()
+    return attn_module

 def get_model(params: AttributeDict) -> nn.Module:
     assert (
@@ -620,11 +624,14 @@ def get_model(params: AttributeDict) -> nn.Module:
         decoder = None
         joiner = None

+    attn = get_attn_module(params)
+
     model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        label_level_am_attention=attn,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         vocab_size=params.vocab_size,
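On the training-script side, the new get_attn_module factory builds the AlignmentAttentionModule and get_model hands it to AsrModel. A self-contained sketch of that flow, reusing the AsrModel sketch above and standing in a stub for the real AlignmentAttentionModule (its constructor arguments here are assumptions, not the real module's signature), could be:

import torch.nn as nn


class AlignmentAttentionModule(nn.Module):
    """Stub standing in for the real module in alignment_attention_module.py."""

    def __init__(self, dim: int = 512) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)


def get_attn_module(params) -> nn.Module:
    # params is accepted for symmetry with the other get_* factories,
    # even though this sketch reads nothing from it.
    return AlignmentAttentionModule()


def get_model(params) -> nn.Module:
    attn = get_attn_module(params)
    return AsrModel(
        encoder=nn.Identity(),  # placeholder for the real encoder
        decoder=None,
        joiner=None,
        label_level_am_attention=attn,
    )


model = get_model(params={})

Building the module in the training script and passing it in keeps the model definition free of a direct dependency on alignment_attention_module, and it makes it straightforward to swap or disable the attention module (for example by passing None) without touching AsrModel itself.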