diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index 8e924bf96..8b48b847f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -74,7 +74,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,
@@ -197,6 +197,8 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )

+    add_model_arguments(parser)
+
     return parser


diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
index 7d2a07817..f4f4eb1c0 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py
@@ -49,7 +49,7 @@ from pathlib import Path

 import sentencepiece as spm
 import torch
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.utils import str2bool
@@ -109,6 +109,8 @@ def get_parser():
         "2 means tri-gram",
     )

+    add_model_arguments(parser)
+
     return parser


diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
index b0eb4d749..4ce6a668b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py
@@ -57,7 +57,7 @@ from beam_search import (
     modified_beam_search,
 )
 from torch.nn.utils.rnn import pad_sequence
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model


 def get_parser():
@@ -133,6 +133,8 @@ def get_parser():
         """,
     )

+    add_model_arguments(parser)
+
     return parser


diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 5520d0168..5bbf35658 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -52,7 +52,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam

@@ -73,6 +72,44 @@ from icefall.utils import (
 )


+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=12,
+        help="Number of transformer encoder layers",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in a transformer encoder layer",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=2048,
+        help="Feedforward dimension of linear layers after attention in "
+        "the transformer model",
+    )
+
+    parser.add_argument(
+        "--attention-dim",
+        type=int,
+        default=512,
+        help="Attention dimension in a transformer encoder layer",
+    )
+
+    parser.add_argument(
+        "--embedding-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension for the decoder network",
+    )
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -229,6 +266,8 @@ def get_parser():
         help="Accumulate stats on activations, print them and exit.",
     )

+    add_model_arguments(parser)
+
     return parser


@@ -270,10 +309,6 @@ def get_params() -> AttributeDict:

         - subsampling_factor: The subsampling factor for the model.

-        - attention_dim: Hidden dim for multi-head attention model.
-
-        - num_decoder_layers: Number of decoder layer of transformer decoder.
-
         - warm_step: The warm_step for Noam optimizer.
     """
     params = AttributeDict(
@@ -290,13 +325,7 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
             "vgg_frontend": False,
-            # parameters for decoder
-            "embedding_dim": 512,
             # parameters for Noam
             "warm_step": 80000,  # For the 100h subset, use 30000
             "env_info": get_env_info(),
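
Reviewer note (illustration, not part of the patch): the effect of this change is that the model hyperparameters move from hard-coded entries in get_params() to CLI flags shared by train.py, decode.py, export.py and pretrained.py, so a checkpoint trained with non-default sizes can be rebuilt identically at decode/export time. Below is a minimal sketch of the assumed flow; the AttributeDict stand-in and the params.update(vars(args)) merge mirror the usual icefall pattern but are written from scratch here, not copied from the repo.

    import argparse


    class AttributeDict(dict):
        """Minimal stand-in for icefall.utils.AttributeDict:
        a dict whose keys are also readable as attributes."""

        def __getattr__(self, key):
            return self[key]


    def add_model_arguments(parser: argparse.ArgumentParser):
        # Two of the five flags from the patch, abbreviated for the sketch.
        parser.add_argument("--num-encoder-layers", type=int, default=12)
        parser.add_argument("--dim-feedforward", type=int, default=2048)


    def get_params() -> AttributeDict:
        # After the patch, get_params() keeps only the non-CLI defaults.
        return AttributeDict({"feature_dim": 80, "subsampling_factor": 4})


    parser = argparse.ArgumentParser()
    add_model_arguments(parser)
    args = parser.parse_args(["--num-encoder-layers", "24"])

    params = get_params()
    params.update(vars(args))  # CLI values land next to the remaining defaults

    # The model factory can now read params.num_encoder_layers == 24 and
    # params.dim_feedforward == 2048; passing the same flags to decode.py or
    # export.py reproduces the trained architecture.
    print(params.num_encoder_layers, params.dim_feedforward)

The practical consequence is that a model trained with, say, a non-default --num-encoder-layers must be decoded and exported with that same flag, which is exactly why the patch calls add_model_arguments(parser) in all four scripts.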