# Summary: replace the nested `worker_init_fn` closure with a module-level
# callable class, `_SeedWorkers`, in two LibriSpeech ASR data modules.
#
# What each file's diff does (both files receive the same two hunks):
#   1. Adds a module-level class `_SeedWorkers` just above the data-module
#      class.  Its constructor stores `seed`; calling an instance with a
#      `worker_id` runs `fix_random_seed(self.seed + worker_id)`.
#   2. Removes the function `worker_init_fn(worker_id)` that was defined
#      inline right after `seed = torch.randint(0, 100000, ()).item()` and
#      binds `worker_init_fn = _SeedWorkers(seed)` instead.  For a given
#      `worker_id` the value handed to `fix_random_seed` is unchanged
#      (`seed + worker_id`).
#
# NOTE(review): the motivation is not stated in the patch; presumably a
#   module-level class instance can be pickled (e.g. for worker processes
#   that are spawned rather than forked) whereas a nested function cannot.
#   Confirm against the commit message.
# NOTE(review): `fix_random_seed` is not imported in the visible hunks; the
#   removed closure already used it, so it is presumably imported elsewhere
#   in each file.  Verify.
# NOTE(review): line breaks and leading whitespace below were restored from
#   a whitespace-collapsed copy of this patch.  Check the context lines
#   against the target files before applying.
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 8790b21e7..8dd1459ca 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -41,6 +41,14 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
 
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
 class LibriSpeechAsrDataModule:
     """
     DataModule for k2 ASR experiments.
@@ -306,9 +314,7 @@ class LibriSpeechAsrDataModule:
         # 'seed' is derived from the current random state, which will have
         # previously been set in the main process.
         seed = torch.randint(0, 100000, ()).item()
-
-        def worker_init_fn(worker_id: int):
-            fix_random_seed(seed + worker_id)
+        worker_init_fn = _SeedWorkers(seed)
 
         train_dl = DataLoader(
             train,
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py
index 2ce8d8752..c6cf739fb 100644
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py
@@ -41,6 +41,14 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
 
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
 class AsrDataModule:
     def __init__(self, args: argparse.Namespace):
         self.args = args
@@ -259,9 +267,7 @@ class AsrDataModule:
         # 'seed' is derived from the current random state, which will have
         # previously been set in the main process.
         seed = torch.randint(0, 100000, ()).item()
-
-        def worker_init_fn(worker_id: int):
-            fix_random_seed(seed + worker_id)
+        worker_init_fn = _SeedWorkers(seed)
 
         train_dl = DataLoader(
             train,