Various bug fixes

This commit is contained in:
Daniel Povey 2022-04-02 20:06:43 +08:00
parent 8be10d3d6c
commit 34500afc43
2 changed files with 4 additions and 5 deletions

View File

@ -38,7 +38,7 @@ class Decoder(nn.Module):
def __init__( def __init__(
self, self,
vocab_size: int, vocab_size: int,
embedding_dim: int, decoder_dim: int,
blank_id: int, blank_id: int,
context_size: int, context_size: int,
): ):

View File

@ -309,7 +309,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer # TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer( encoder = Conformer(
num_features=params.feature_dim, num_features=params.feature_dim,
output_dim=params.embedding_dim,
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
d_model=params.encoder_dim, d_model=params.encoder_dim,
nhead=params.nhead, nhead=params.nhead,
@ -322,7 +321,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
def get_decoder_model(params: AttributeDict) -> nn.Module: def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder( decoder = Decoder(
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim, decoder_dim=params.decoder_dim,
blank_id=params.blank_id, blank_id=params.blank_id,
context_size=params.context_size, context_size=params.context_size,
) )
@ -331,7 +330,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
def get_joiner_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner( joiner = Joiner(
encoder_dim=params.encoder_dim encoder_dim=params.encoder_dim,
decoder_dim=params.decoder_dim, decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim, joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
@ -348,7 +347,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder, encoder=encoder,
decoder=decoder, decoder=decoder,
joiner=joiner, joiner=joiner,
encoder_dim=params.encoder_dim encoder_dim=params.encoder_dim,
decoder_dim=params.decoder_dim, decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim, joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size, vocab_size=params.vocab_size,