diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index e5987b75e..f581fa13d 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -51,6 +51,7 @@ from conformer import Conformer from decoder import Decoder from joiner import Joiner from model import Transducer +from transformer import Transformer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info @@ -128,6 +129,13 @@ def get_parser(): help="Maximum number of symbols per frame", ) + parser.add_argument( + "--encoder-type", + type=str, + default="conformer", + help="Type of the encoder. Valid values are: conformer and transformer", + ) + return parser @@ -150,8 +158,16 @@ def get_params() -> AttributeDict: def get_encoder_model(params: AttributeDict): - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( + if params.encoder_type == "conformer": + Encoder = Conformer + elif params.encoder_type == "transformer": + Encoder = Transformer + else: + raise ValueError( + f"Unsupported encoder type: {params.encoder_type}" + "\nPlease use conformer or transformer" + ) + encoder = Encoder( num_features=params.feature_dim, output_dim=params.encoder_out_dim, subsampling_factor=params.subsampling_factor,