diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
index 7475632a6..5214e048e 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py
@@ -251,7 +251,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=8000,
+        default=200,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -333,7 +333,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 50,
+            "log_interval": 1,
             "reset_interval": 200,
             "valid_interval": 3000,
             # parameters for conformer
@@ -682,9 +682,9 @@ def train_one_epoch(
 
     cur_batch_idx = params.get("cur_batch_idx", 0)
 
-    for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
+    for batch_idx, batch in enumerate(train_dl, cur_batch_idx):
+        # if batch_idx < cur_batch_idx:
+        #     continue
         cur_batch_idx = batch_idx
 
         params.batch_idx_train += 1
@@ -909,6 +909,7 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
+    """
     if not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
@@ -917,7 +918,7 @@ def run(rank, world_size, args):
             graph_compiler=graph_compiler,
             params=params,
         )
-
+    """
     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")