From d79f5fecf74c69172fe1e9d5104c9f9bb6f4bbb3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 25 Apr 2022 17:26:43 +0800 Subject: [PATCH] Pass model parameters from the command line. --- .../pruned_transducer_stateless4/decode.py | 22 ++++--- .../test_model.py | 22 ++++++- .../ASR/pruned_transducer_stateless4/train.py | 64 +++++++++++++++---- 3 files changed, 85 insertions(+), 23 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 38aff8834..e706083ff 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -18,36 +18,36 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless2/decode.py \ +./pruned_transducer_stateless4/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless4/exp \ --max-duration 100 \ --decoding-method greedy_search (2) beam search -./pruned_transducer_stateless2/decode.py \ +./pruned_transducer_stateless4/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless4/exp \ --max-duration 100 \ --decoding-method beam_search \ --beam-size 4 (3) modified beam search -./pruned_transducer_stateless2/decode.py \ +./pruned_transducer_stateless4/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless4/exp \ --max-duration 100 \ --decoding-method modified_beam_search \ --beam-size 4 (4) fast beam search -./pruned_transducer_stateless2/decode.py \ +./pruned_transducer_stateless4/decode.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless4/exp \ --max-duration 1500 \ --decoding-method fast_beam_search \ --beam 4 \ @@ -74,7 +74,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, @@ -124,7 +124,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="pruned_transducer_stateless4/exp", help="The experiment dir", ) @@ -197,6 +197,8 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + add_model_arguments(parser) + return parser diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py index 43f84e5c7..9aad32014 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/test_model.py @@ -26,7 +26,7 @@ To run this file, do: from train import get_params, get_transducer_model -def test_model(): +def test_model_1(): params = get_params() params.vocab_size = 500 params.blank_id = 0 @@ -39,8 +39,26 @@ def test_model(): print(f"Number of model parameters: {num_param}") +# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf +def test_model_M(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = 18 + params.dim_feedforward = 1024 + params.encoder_dim = 256 + params.nhead = 4 + params.decoder_dim = 512 + params.joiner_dim = 512 + model = get_transducer_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + def main(): - test_model() + # test_model_1() + test_model_M() if __name__ == "__main__": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 31617c3b0..b759e77ab 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -83,6 +83,53 @@ LRSchedulerType = Union[ ] +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=int, + default=24, + help="Number of conformer encoder layers..", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=1536, + help="Feedforward dimension of the conformer encoder layer.", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads in the conformer encoder layer.", + ) + + parser.add_argument( + "--encoder-dim", + type=int, + default=384, + help="Attention dimension in the conformer encoder layer.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -156,15 +203,16 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( "--lr-batches", type=float, default=5000, - help="""Number of steps that affects how rapidly the learning rate decreases. - We suggest not to change this.""", + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", ) parser.add_argument( @@ -262,6 +310,8 @@ def get_parser(): help="Whether to use half precision training.", ) + add_model_arguments(parser) + return parser @@ -322,14 +372,6 @@ def get_params() -> AttributeDict: # parameters for conformer "feature_dim": 80, "subsampling_factor": 4, - "encoder_dim": 384, - "nhead": 8, - "dim_feedforward": 1536, - "num_encoder_layers": 24, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, # parameters for Noam "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(),