diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index dc957fd04..8dd1459ca 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -42,11 +42,11 @@ from icefall.utils import str2bool
 
 
 class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
     def __call__(self, worker_id: int):
-        # 'seed' is derived from the current random state, which will have
-        # previously been set in the main process.
-        seed = torch.randint(0, 100000, ()).item()
-        fix_random_seed(seed + worker_id)
+        fix_random_seed(self.seed + worker_id)
 
 
 class LibriSpeechAsrDataModule:
@@ -311,13 +311,18 @@ class LibriSpeechAsrDataModule:
             logging.info("Loading sampler state dict")
             train_sampler.load_state_dict(sampler_state_dict)
 
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
-            worker_init_fn=_SeedWorkers(),
+            worker_init_fn=worker_init_fn,
         )
 
         return train_dl
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py
index c3aaea782..c6cf739fb 100644
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py
@@ -42,11 +42,11 @@ from icefall.utils import str2bool
 
 
 class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
     def __call__(self, worker_id: int):
-        # 'seed' is derived from the current random state, which will have
-        # previously been set in the main process.
-        seed = torch.randint(0, 100000, ()).item()
-        fix_random_seed(seed + worker_id)
+        fix_random_seed(self.seed + worker_id)
 
 
 class AsrDataModule:
@@ -264,13 +264,18 @@ class AsrDataModule:
 
         logging.info("About to create train dataloader")
 
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
-            worker_init_fn=_SeedWorkers(),
+            worker_init_fn=worker_init_fn,
        )
 
         return train_dl
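
Note on the change (not part of the patch): the old comment, "'seed' is derived from the current random state, which will have previously been set in the main process", only holds if `torch.randint` is actually evaluated in the main process; the patch makes that true by drawing the seed once during dataloader setup and passing it to `_SeedWorkers`, which then only adds `worker_id`. Below is a minimal, self-contained sketch of the same pattern. It seeds `random` and `torch` directly instead of calling `fix_random_seed`, and `ToyDataset` is a made-up stand-in, so treat it as an illustration rather than icefall code.

```python
# Sketch of the worker-seeding pattern from the patch (illustration only).
# The seed is drawn ONCE in the main process, then each DataLoader worker
# offsets it by its worker_id, so worker RNGs are distinct but reproducible.
import random

import torch
from torch.utils.data import DataLoader, Dataset


class _SeedWorkers:
    """Picklable callable; a plain class (not a lambda) survives worker spawn."""

    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int) -> None:
        # The real code calls fix_random_seed(); seeding torch/random directly
        # here is just for illustration.
        random.seed(self.seed + worker_id)
        torch.manual_seed(self.seed + worker_id)


class ToyDataset(Dataset):
    """Hypothetical dataset whose samples depend on the worker's RNG."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.rand(1)


if __name__ == "__main__":
    # 'seed' is derived from the main-process RNG, exactly as in the patch.
    seed = torch.randint(0, 100000, ()).item()
    dl = DataLoader(
        ToyDataset(),
        batch_size=2,
        num_workers=2,
        worker_init_fn=_SeedWorkers(seed),
    )
    for batch in dl:
        print(batch)
```

With this arrangement the per-worker seeds depend only on the main-process RNG state at dataloader-creation time, rather than on whatever RNG state each worker process happens to have when `worker_init_fn` runs.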