diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index 25a36223d..3291ad877 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -38,7 +38,7 @@ class Decoder(nn.Module): def __init__( self, vocab_size: int, - embedding_dim: int, + decoder_dim: int, blank_id: int, context_size: int, ): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index a027a5adc..e8fbb6a71 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -309,7 +309,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Conformer and Transformer encoder = Conformer( num_features=params.feature_dim, - output_dim=params.embedding_dim, subsampling_factor=params.subsampling_factor, d_model=params.encoder_dim, nhead=params.nhead, @@ -322,7 +321,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, + decoder_dim=params.decoder_dim, blank_id=params.blank_id, context_size=params.context_size, ) @@ -331,7 +330,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=params.encoder_dim + encoder_dim=params.encoder_dim, decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -348,7 +347,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=params.encoder_dim + encoder_dim=params.encoder_dim, decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size,