Don't use a lambda for dataloader's worker_init_fn.

Fangjun Kuang 2022-03-31 20:22:18 +08:00
parent 9a11808ed3
commit 9e5ca821bc
2 changed files with 18 additions and 16 deletions
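
Why this change: PyTorch's DataLoader passes worker_init_fn to each worker process. When workers are started with the "spawn" multiprocessing method (the default on macOS and Windows), that callable must be pickled, and pickle can only serialize functions and classes reachable through a module-level name; lambdas and functions defined inside another function both fail. Replacing the inline function with an instance of a module-level class keeps the same behaviour while remaining picklable. A minimal sketch of the difference (the names below are illustrative, not taken from icefall):

import pickle


class SeedWorkers:
    """Picklable: the class is defined at module level."""

    def __call__(self, worker_id: int) -> None:
        print(f"initializing worker {worker_id}")


pickle.dumps(SeedWorkers())  # succeeds

try:
    pickle.dumps(lambda worker_id: None)
except Exception as e:
    print(f"a lambda is not picklable: {e}")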

@@ -41,6 +41,14 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
 
+class _SeedWorkers:
+    def __call__(self, worker_id: int):
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        fix_random_seed(seed + worker_id)
+
+
 class LibriSpeechAsrDataModule:
     """
     DataModule for k2 ASR experiments.
@@ -303,20 +311,13 @@ class LibriSpeechAsrDataModule:
             logging.info("Loading sampler state dict")
             train_sampler.load_state_dict(sampler_state_dict)
 
-        # 'seed' is derived from the current random state, which will have
-        # previously been set in the main process.
-        seed = torch.randint(0, 100000, ()).item()
-
-        def worker_init_fn(worker_id: int):
-            fix_random_seed(seed + worker_id)
-
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
-            worker_init_fn=worker_init_fn,
+            worker_init_fn=_SeedWorkers(),
         )
 
         return train_dl
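
For contrast, here is a self-contained repro of the pattern this hunk removes, under stated assumptions: a toy list dataset, torch.manual_seed standing in for icefall's fix_random_seed, and a hypothetical make_loader helper. Under the "spawn" start method the DataLoader fails while starting its workers, because the locally defined worker_init_fn cannot be pickled:

import torch
from torch.utils.data import DataLoader


def make_loader() -> DataLoader:
    seed = torch.randint(0, 100000, ()).item()

    def worker_init_fn(worker_id: int):  # local function: not picklable
        torch.manual_seed(seed + worker_id)

    return DataLoader(
        list(range(8)),
        batch_size=2,
        num_workers=1,
        worker_init_fn=worker_init_fn,
        multiprocessing_context="spawn",
    )


if __name__ == "__main__":
    try:
        next(iter(make_loader()))
    except Exception as e:
        # AttributeError: Can't pickle local object
        # 'make_loader.<locals>.worker_init_fn'
        print(e)

The same script runs fine on Linux with the default "fork" start method, where worker arguments are inherited rather than pickled, which is why the bug only surfaces on platforms or configurations that spawn workers.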

@@ -41,6 +41,14 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
 
+class _SeedWorkers:
+    def __call__(self, worker_id: int):
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        fix_random_seed(seed + worker_id)
+
+
 class AsrDataModule:
     def __init__(self, args: argparse.Namespace):
         self.args = args
@@ -256,20 +264,13 @@ class AsrDataModule:
 
         logging.info("About to create train dataloader")
 
-        # 'seed' is derived from the current random state, which will have
-        # previously been set in the main process.
-        seed = torch.randint(0, 100000, ()).item()
-
-        def worker_init_fn(worker_id: int):
-            fix_random_seed(seed + worker_id)
-
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
-            worker_init_fn=worker_init_fn,
+            worker_init_fn=_SeedWorkers(),
        )
 
         return train_dl
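
Putting it together, a condensed runnable sketch of the pattern the commit adopts, with the same stand-ins as above (torch.manual_seed for fix_random_seed, a toy dataset instead of the ASR cuts):

import torch
from torch.utils.data import DataLoader


class _SeedWorkers:
    # A module-level class: its instances pickle cleanly, so the loader
    # also works when workers are started with "spawn".
    def __call__(self, worker_id: int) -> None:
        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        torch.manual_seed(seed + worker_id)


if __name__ == "__main__":
    train_dl = DataLoader(
        list(range(8)),
        batch_size=2,
        num_workers=2,
        worker_init_fn=_SeedWorkers(),
        multiprocessing_context="spawn",
    )
    for batch in train_dl:
        print(batch)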