mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 14:14:19 +00:00
Don't use a lambda for dataloader's worker_init_fn.
This commit is contained in:
parent
9a11808ed3
commit
9e5ca821bc
@ -41,6 +41,14 @@ from torch.utils.data import DataLoader
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __call__(self, worker_id: int):
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
fix_random_seed(seed + worker_id)
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
@ -303,20 +311,13 @@ class LibriSpeechAsrDataModule:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
|
||||
def worker_init_fn(worker_id: int):
|
||||
fix_random_seed(seed + worker_id)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
worker_init_fn=_SeedWorkers(),
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
@ -41,6 +41,14 @@ from torch.utils.data import DataLoader
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __call__(self, worker_id: int):
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
fix_random_seed(seed + worker_id)
|
||||
|
||||
|
||||
class AsrDataModule:
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
@ -256,20 +264,13 @@ class AsrDataModule:
|
||||
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
|
||||
def worker_init_fn(worker_id: int):
|
||||
fix_random_seed(seed + worker_id)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
worker_init_fn=_SeedWorkers(),
|
||||
)
|
||||
return train_dl
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user