diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py
index 694ebf1d5..23798f04e 100755
--- a/egs/librispeech/ASR/transducer_stateless/train.py
+++ b/egs/librispeech/ASR/transducer_stateless/train.py
@@ -54,7 +54,7 @@ from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from transformer import Noam
+from transformer import Noam, Transformer
 
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
@@ -138,6 +138,13 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--encoder-type",
+        type=str,
+        default="conformer",
+        help="Type of the encoder. Valid values are: conformer and transformer",
+    )
+
     return parser
 
 
@@ -214,8 +221,16 @@ def get_params() -> AttributeDict:
 
 
 def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
+    if params.encoder_type == "conformer":
+        Encoder = Conformer
+    elif params.encoder_type == "transformer":
+        Encoder = Transformer
+    else:
+        raise ValueError(
+            f"Unsupported encoder type: {params.encoder_type}"
+            "\nPlease use conformer or transformer"
+        )
+    encoder = Encoder(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
         subsampling_factor=params.subsampling_factor,
@@ -360,7 +375,8 @@ def compute_loss(
       params:
         Parameters for training. See :func:`get_params`.
       model:
-        The model for training. It is an instance of Conformer in our case.
+        The model for training. It is an instance of Conformer or Transformer
+        in our case.
       batch:
         A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
         for the content in it.
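
Usage sketch (assumed invocation; the required data-preparation steps and any other
training flags of the recipe are omitted here):

    cd egs/librispeech/ASR
    ./transducer_stateless/train.py --encoder-type transformer

The new flag defaults to "conformer", so existing training runs keep the current
behaviour; any value other than "conformer" or "transformer" raises a ValueError.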