From 9e5ca821bc9f1ed117933b2cf947e61eb830d806 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 31 Mar 2022 20:22:18 +0800
Subject: [PATCH] Don't use a lambda for dataloader's worker_init_fn.

---
Review note: the base seed is still drawn once in the main process and is
passed to _SeedWorkers, instead of being re-drawn inside every worker. This
keeps the seeding behaviour of the nested function being replaced. The blob
ids on the "index" lines below were not regenerated for this revision.

 .../ASR/tdnn_lstm_ctc/asr_datamodule.py         | 18 +++++++++++++++---
 .../asr_datamodule.py                           | 18 +++++++++++++++---
 2 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 8790b21e7..dc957fd04 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -41,6 +41,20 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
 
+class _SeedWorkers:
+    """Callable ``worker_init_fn`` that seeds each dataloader worker.
+
+    ``seed`` is drawn once in the main process; worker ``worker_id`` then
+    calls ``fix_random_seed(seed + worker_id)``.
+    """
+
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
 class LibriSpeechAsrDataModule:
     """
     DataModule for k2 ASR experiments.
@@ -306,9 +320,7 @@ class LibriSpeechAsrDataModule:
         # 'seed' is derived from the current random state, which will have
         # previously been set in the main process.
         seed = torch.randint(0, 100000, ()).item()
-
-        def worker_init_fn(worker_id: int):
-            fix_random_seed(seed + worker_id)
+        worker_init_fn = _SeedWorkers(seed)
 
         train_dl = DataLoader(
             train,
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py
index 2ce8d8752..c3aaea782 100644
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py
@@ -41,6 +41,20 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
 
+class _SeedWorkers:
+    """Callable ``worker_init_fn`` that seeds each dataloader worker.
+
+    ``seed`` is drawn once in the main process; worker ``worker_id`` then
+    calls ``fix_random_seed(seed + worker_id)``.
+    """
+
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
 class AsrDataModule:
     def __init__(self, args: argparse.Namespace):
         self.args = args
@@ -259,9 +273,7 @@ class AsrDataModule:
         # 'seed' is derived from the current random state, which will have
         # previously been set in the main process.
         seed = torch.randint(0, 100000, ()).item()
-
-        def worker_init_fn(worker_id: int):
-            fix_random_seed(seed + worker_id)
+        worker_init_fn = _SeedWorkers(seed)
 
         train_dl = DataLoader(
             train,