From 1c35ae1dba719c03aef3be9198840ca621c131bb Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 21 Feb 2022 15:16:39 +0800 Subject: [PATCH] Reset seed at the beginning of each epoch. (#221) * Reset seed at the beginning of each epoch. * Use a different seed for each epoch. --- egs/aishell/ASR/conformer_ctc/train.py | 10 +++++++++- egs/aishell/ASR/conformer_mmi/train.py | 10 +++++++++- egs/aishell/ASR/tdnn_lstm_ctc/train.py | 10 +++++++++- egs/aishell/ASR/transducer_stateless/train.py | 10 +++++++++- egs/librispeech/ASR/conformer_ctc/train.py | 10 +++++++++- egs/librispeech/ASR/conformer_mmi/train.py | 10 +++++++++- .../ASR/pruned_transducer_stateless/train.py | 10 +++++++++- egs/librispeech/ASR/streaming_conformer_ctc/train.py | 10 +++++++++- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 10 +++++++++- egs/librispeech/ASR/transducer/train.py | 10 +++++++++- egs/librispeech/ASR/transducer_lstm/train.py | 10 +++++++++- egs/librispeech/ASR/transducer_stateless/train.py | 10 +++++++++- egs/timit/ASR/tdnn_ligru_ctc/train.py | 10 +++++++++- egs/timit/ASR/tdnn_lstm_ctc/train.py | 10 +++++++++- egs/yesno/ASR/tdnn/train.py | 10 +++++++++- egs/yesno/ASR/transducer/train.py | 10 +++++++++- 16 files changed, 144 insertions(+), 16 deletions(-) diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py index a4bc8e3bb..369ad310f 100755 --- a/egs/aishell/ASR/conformer_ctc/train.py +++ b/egs/aishell/ASR/conformer_ctc/train.py @@ -121,6 +121,13 @@ def get_parser(): """, ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -555,7 +562,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -618,6 +625,7 @@ def run(rank, world_size, args): valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) cur_lr = optimizer._rate diff --git a/egs/aishell/ASR/conformer_mmi/train.py b/egs/aishell/ASR/conformer_mmi/train.py index 79c16d1cc..685831d09 100755 --- a/egs/aishell/ASR/conformer_mmi/train.py +++ b/egs/aishell/ASR/conformer_mmi/train.py @@ -124,6 +124,13 @@ def get_parser(): """, ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -546,7 +553,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -613,6 +620,7 @@ def run(rank, world_size, args): valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) cur_lr = optimizer._rate diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py index a0045115d..3327cdb79 100755 --- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py +++ b/egs/aishell/ASR/tdnn_lstm_ctc/train.py @@ -92,6 +92,13 @@ def get_parser(): """, ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -507,7 +514,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -557,6 +564,7 @@ def run(rank, world_size, args): valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) if epoch > params.start_epoch: diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py index cd37810dd..f615c78f4 100755 --- a/egs/aishell/ASR/transducer_stateless/train.py +++ b/egs/aishell/ASR/transducer_stateless/train.py @@ -129,6 +129,13 @@ def get_parser(): "2 means tri-gram", ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -534,7 +541,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -611,6 +618,7 @@ def run(rank, world_size, args): valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) cur_lr = optimizer._rate diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 058efd061..b81bd6330 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -140,6 +140,13 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -580,7 +587,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -684,6 +691,7 @@ def run(rank, world_size, args): ) for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) cur_lr = optimizer._rate diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index c36677762..9a5bdcce2 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -109,6 +109,13 @@ def get_parser(): """, ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -673,7 +680,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -761,6 +768,7 @@ def run(rank, world_size, args): valid_dl = librispeech.valid_dataloaders() for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) if ( params.batch_idx_train >= params.use_ali_until diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index e19473788..f0ea2ccaa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -179,6 +179,13 @@ def get_parser(): "with this parameter before adding to the final loss.", ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -642,7 +649,7 @@ def run(rank, world_size, args): params.valid_interval = 800 params.warm_step = 30000 - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -731,6 +738,7 @@ def run(rank, world_size, args): ) for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) cur_lr = optimizer._rate diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py index 8b4d6701e..9beb185a2 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -138,6 +138,13 @@ def get_parser(): help="Proportion of samples trained with short right context", ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -575,7 +582,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -645,6 +652,7 @@ def run(rank, world_size, args): ) for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) cur_lr = optimizer._rate diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 7439e157a..8597525ba 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -95,6 +95,13 @@ def get_parser(): """, ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -486,7 +493,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -544,6 +551,7 @@ def run(rank, world_size, args): valid_dl = librispeech.valid_dataloaders(valid_cuts) for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) if epoch > params.start_epoch: diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index 903ba8491..a6ce79520 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -130,6 +130,13 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -544,7 +551,7 @@ def run(rank, world_size, args): params.valid_interval = 800 params.warm_step = 8000 - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -633,6 +640,7 @@ def run(rank, world_size, args): ) for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) cur_lr = optimizer._rate diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 62e9b5b12..9f06ed512 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -131,6 +131,13 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -548,7 +555,7 @@ def run(rank, world_size, args): params.valid_interval = 800 params.warm_step = 8000 - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -639,6 +646,7 @@ def run(rank, world_size, args): ) for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) cur_lr = optimizer._rate diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 544f6e9b1..4f5379e53 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -149,6 +149,13 @@ def get_parser(): """, ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -562,7 +569,7 @@ def run(rank, world_size, args): params.valid_interval = 800 params.warm_step = 8000 - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -651,6 +658,7 @@ def run(rank, world_size, args): ) for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) cur_lr = optimizer._rate diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py index 9ac4743b4..452c2a7cb 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/train.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py @@ -95,6 +95,13 @@ def get_parser(): """, ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -486,7 +493,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -536,6 +543,7 @@ def run(rank, world_size, args): valid_dl = timit.valid_dataloaders() for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) if epoch > params.start_epoch: diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py index 2a6ff4787..849256b98 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/train.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py @@ -95,6 +95,13 @@ def get_parser(): """, ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -486,7 +493,7 @@ def run(rank, world_size, args): params = get_params() params.update(vars(args)) - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -536,6 +543,7 @@ def run(rank, world_size, args): valid_dl = timit.valid_dataloaders() for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) if epoch > params.start_epoch: diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index d8454b7c5..f32a27f35 100755 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -71,6 +71,13 @@ def get_parser(): """, ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -468,7 +475,7 @@ def run(rank, world_size, args): params.update(vars(args)) params["env_info"] = get_env_info() - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -520,6 +527,7 @@ def run(rank, world_size, args): valid_dl = yes_no.test_dataloaders() for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) if tb_writer is not None: diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py index 7d2d1edeb..deb92107d 100755 --- a/egs/yesno/ASR/transducer/train.py +++ b/egs/yesno/ASR/transducer/train.py @@ -114,6 +114,13 @@ def get_parser(): help="Directory to save results", ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + return parser @@ -487,7 +494,7 @@ def run(rank, world_size, args): params.update(vars(args)) params["env_info"] = get_env_info() - fix_random_seed(42) + fix_random_seed(params.seed) if world_size > 1: setup_dist(rank, world_size, params.master_port) @@ -532,6 +539,7 @@ def run(rank, world_size, args): valid_dl = yes_no.test_dataloaders() for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) if tb_writer is not None: