From 1138b27f167067b53f75be9fa1c038fe35c9c01b Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 10 Aug 2022 17:16:09 +0800 Subject: [PATCH] modify pretrained.py --- .../ASR/lstm_transducer_stateless/export.py | 2 +- .../lstm_transducer_stateless/pretrained.py | 28 ++++++++++--------- .../ASR/lstm_transducer_stateless/train.py | 9 +++--- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py index 49ba93d55..9fa841bcc 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -358,9 +358,9 @@ def main(): model.to("cpu") model.eval() - convert_scaled_to_non_scaled(model, inplace=True) if params.jit_trace is True: + convert_scaled_to_non_scaled(model, inplace=True) logging.info("Using torch.jit.trace()") encoder_filename = params.exp_dir / "encoder_jit_trace.pt" export_encoder_model_jit_trace(model.encoder, encoder_filename) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py index 21bcf7cfd..2a6e2adc6 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -18,16 +18,16 @@ Usage: (1) greedy search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ --bpe-model ./data/lang_bpe_500/bpe.model \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav (2) beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ --bpe-model ./data/lang_bpe_500/bpe.model \ --method beam_search \ --beam-size 4 \ @@ -35,8 +35,8 @@ Usage: /path/to/bar.wav (3) modified beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ --bpe-model ./data/lang_bpe_500/bpe.model \ --method modified_beam_search \ --beam-size 4 \ @@ -44,18 +44,18 @@ Usage: /path/to/bar.wav (4) fast beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ --bpe-model ./data/lang_bpe_500/bpe.model \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ /path/to/bar.wav -You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`. +You can also use `./lstm_transducer_stateless/exp/epoch-xx.pt`. -Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by -./pruned_transducer_stateless2/export.py +Note: ./lstm_transducer_stateless/exp/pretrained.pt is generated by +./lstm_transducer_stateless/export.py """ @@ -77,7 +77,7 @@ from beam_search import ( modified_beam_search, ) from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model +from train import add_model_arguments, get_params, get_transducer_model def get_parser(): @@ -178,6 +178,8 @@ def get_parser(): """, ) + add_model_arguments(parser) + return parser @@ -268,7 +270,7 @@ def main(): feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder( + encoder_out, encoder_out_lens, _ = model.encoder( x=features, x_lens=feature_lengths ) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index 89bd406b1..8d07aae5e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -111,9 +111,10 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--aux-layer-period", type=int, - default=3, + default=0, help="""Peroid of auxiliary layers used for randomly combined during training. - If not larger than 0 (e.g., -1), will not use the random combiner. + If set to 0, will not use the random combiner (Default). + You can set a positive integer to use the random combiner, e.g., 3. """, ) @@ -206,7 +207,7 @@ def get_parser(): parser.add_argument( "--lr-epochs", type=float, - default=6, + default=10, help="""Number of epochs that affects how rapidly the learning rate decreases. """, ) @@ -270,7 +271,7 @@ def get_parser(): parser.add_argument( "--save-every-n", type=int, - default=8000, + default=4000, help="""Save checkpoint after processing this number of batches" periodically. We save checkpoint to exp-dir/ whenever params.batch_idx_train % save_every_n == 0. The checkpoint filename