Don't use a lambda for dataloader's worker_init_fn. (#284)

* Don't use a lambda for dataloader's worker_init_fn.
This commit is contained in:
Fangjun Kuang 2022-03-31 20:32:00 +08:00 committed by GitHub
parent 9a11808ed3
commit e7493ede90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 6 deletions

View File

@ -41,6 +41,14 @@ from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
    """Picklable callable for use as a DataLoader's ``worker_init_fn``.

    Unlike a lambda, an instance of this class can be pickled, which is
    required when dataloader workers are spawned as separate processes.
    Each worker is seeded deterministically from a base seed offset by
    its worker id.
    """

    def __init__(self, seed: int):
        # Base seed; each worker adds its own id to it.
        self.seed = seed

    def __call__(self, worker_id: int):
        # Give every worker a distinct, reproducible random state.
        fix_random_seed(worker_id + self.seed)
class LibriSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
@ -306,9 +314,7 @@ class LibriSpeechAsrDataModule:
# 'seed' is derived from the current random state, which will have # 'seed' is derived from the current random state, which will have
# previously been set in the main process. # previously been set in the main process.
seed = torch.randint(0, 100000, ()).item() seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
def worker_init_fn(worker_id: int):
fix_random_seed(seed + worker_id)
train_dl = DataLoader( train_dl = DataLoader(
train, train,

View File

@ -41,6 +41,14 @@ from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
    """Seed dataloader workers deterministically.

    A plain class instance is used instead of a lambda so that the
    ``worker_init_fn`` can be pickled when workers run in subprocesses.
    Worker ``i`` receives the seed ``base_seed + i``.
    """

    def __init__(self, seed: int):
        # Remember the base seed chosen by the main process.
        self.seed = seed

    def __call__(self, worker_id: int):
        # Derive a per-worker seed so workers don't share random state.
        fix_random_seed(worker_id + self.seed)
class AsrDataModule:
def __init__(self, args: argparse.Namespace):
self.args = args
@ -259,9 +267,7 @@ class AsrDataModule:
# 'seed' is derived from the current random state, which will have # 'seed' is derived from the current random state, which will have
# previously been set in the main process. # previously been set in the main process.
seed = torch.randint(0, 100000, ()).item() seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
def worker_init_fn(worker_id: int):
fix_random_seed(seed + worker_id)
train_dl = DataLoader( train_dl = DataLoader(
train, train,