From 4b567e480f854aa13c8400f13e70856974d1d5a4 Mon Sep 17 00:00:00 2001
From: luomingshuang <739314837@qq.com>
Date: Wed, 27 Apr 2022 13:43:54 +0800
Subject: [PATCH] adapt for S, M and L training subsets

---
 .../asr_datamodule.py                         | 15 ++++++++++-----
 .../pruned_transducer_stateless2/decode.py    | 14 ++++++++++++++
 .../ASR/pruned_transducer_stateless2/train.py | 26 ++++++++++++++++++++++----
 3 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
index 744db8109..e969ffaeb 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py
@@ -198,6 +198,13 @@ class WenetSpeechAsrDataModule:
             help="lazily open CutSets to avoid OOM (for L|XL subset)",
         )
 
+        group.add_argument(
+            "--training-subset",
+            type=str,
+            default="L",
+            help="The training subset to use, e.g. S, M or L.",
+        )
+
     def train_dataloaders(
         self,
         cuts_train: CutSet,
@@ -316,6 +323,7 @@ class WenetSpeechAsrDataModule:
 
         if sampler_state_dict is not None:
             logging.info("Loading sampler state dict")
+            logging.info(f"sampler state keys: {sampler_state_dict.keys()}")
             train_sampler.load_state_dict(sampler_state_dict)
 
         # 'seed' is derived from the current random state, which will have
@@ -396,7 +404,6 @@ class WenetSpeechAsrDataModule:
             world_size=1,
             shuffle=False,
         )
-
         from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
 
         test_iter_dataset = IterableDatasetWrapper(
@@ -417,14 +424,12 @@ class WenetSpeechAsrDataModule:
             logging.info("use lazy cuts")
             cuts_train = CutSet.from_jsonl_lazy(
                 self.args.manifest_dir
-                / "cuts_L.jsonl.gz"
-                # use cuts_L_50_pieces.jsonl.gz for original experiments
+                / f"cuts_{self.args.training_subset}.jsonl.gz"
            )
         else:
             cuts_train = CutSet.from_file(
                 self.args.manifest_dir
-                / "cuts_L.jsonl.gz"
-                # use cuts_L_50_pieces.jsonl.gz for original experiments
+                / f"cuts_{self.args.training_subset}.jsonl.gz"
             )
 
         return cuts_train
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index 49be13d43..ab8262a71 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -101,6 +101,15 @@ def get_parser():
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
+
+    parser.add_argument(
+        "--batch",
+        type=int,
+        default=None,
+        help="It specifies the batch checkpoint to use for decoding, "
+        "i.e. exp-dir/checkpoint-{batch}.pt.",
+    )
+
     parser.add_argument(
         "--avg",
         type=int,
@@ -499,6 +508,11 @@ def main():
         model.load_state_dict(average_checkpoints(filenames, device=device))
     elif params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    elif params.batch is not None:
+        filenames = f"{params.exp_dir}/checkpoint-{params.batch}.pt"
+        logging.info(f"loading {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints([filenames], device=device))
     else:
         start = params.epoch - params.avg + 1
         filenames = []
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index 5214e048e..27b889a91 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -251,7 +251,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=200,
+        default=8000,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -279,6 +279,26 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--valid-interval",
+        type=int,
+        default=3000,
+        help="""When training_subset is L, set the valid_interval to 3000.
+        When training_subset is M, set the valid_interval to 1000.
+        When training_subset is S, set the valid_interval to 400.
+        """,
+    )
+
+    parser.add_argument(
+        "--model-warm-step",
+        type=int,
+        default=3000,
+        help="""When training_subset is L, set the model_warm_step to 3000.
+        When training_subset is M, set the model_warm_step to 500.
+        When training_subset is S, set the model_warm_step to 100.
+        """,
+    )
+
     return parser
 
 
@@ -333,9 +353,8 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 1,
+            "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 3000,
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
@@ -348,7 +367,6 @@ def get_params() -> AttributeDict:
             # parameters for joiner
             "joiner_dim": 512,
             # parameters for Noam
-            "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )