Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Don't use a lambda for dataloader's worker_init_fn. (#284)
* Don't use a lambda for dataloader's worker_init_fn.
parent 9a11808ed3
commit e7493ede90
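The change replaces a worker_init_fn defined inside the method (a closure over seed) with a small module-level callable class, _SeedWorkers. The likely motivation is picklability: DataLoader has to hand worker_init_fn to each worker process, and when workers are started with the "spawn" method (the default on macOS and Windows) that object must be picklable, which lambdas and nested functions are not. An instance of a top-level class that stores the seed as an attribute pickles fine. A minimal sketch of the difference, outside icefall (the SeedWorkers name and the use of random.seed as a stand-in for fix_random_seed are illustrative only):

import pickle
import random


class SeedWorkers:
    """Picklable equivalent of the closure: keeps the base seed as state."""

    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int) -> None:
        # icefall calls fix_random_seed() here; random.seed() is a stand-in.
        random.seed(self.seed + worker_id)


if __name__ == "__main__":
    pickle.dumps(SeedWorkers(42))  # succeeds: instances of a top-level class pickle

    closure = lambda worker_id: random.seed(42 + worker_id)
    try:
        pickle.dumps(closure)
    except (pickle.PicklingError, AttributeError) as exc:
        print("lambda cannot be pickled:", exc)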
Changes to the file defining LibriSpeechAsrDataModule:

@@ -41,6 +41,14 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
 
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
 class LibriSpeechAsrDataModule:
     """
     DataModule for k2 ASR experiments.

@@ -306,9 +314,7 @@ class LibriSpeechAsrDataModule:
         # 'seed' is derived from the current random state, which will have
         # previously been set in the main process.
         seed = torch.randint(0, 100000, ()).item()
-
-        def worker_init_fn(worker_id: int):
-            fix_random_seed(seed + worker_id)
+        worker_init_fn = _SeedWorkers(seed)
 
         train_dl = DataLoader(
             train,
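For context, here is a self-contained sketch of how a callable like _SeedWorkers plugs into a DataLoader so that every worker process seeds its RNGs from base seed + worker_id. The toy dataset, batch size, and the use of random.seed in place of fix_random_seed are illustrative assumptions, not taken from icefall; only the _SeedWorkers pattern and the torch.randint seed draw mirror the patch:

import random

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Trivial dataset so the example runs on its own."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> int:
        return idx


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int) -> None:
        # Deterministic but distinct seed per worker; stand-in for fix_random_seed().
        random.seed(self.seed + worker_id)


if __name__ == "__main__":
    # As in the patch: the base seed is drawn once in the main process.
    seed = torch.randint(0, 100000, ()).item()
    worker_init_fn = _SeedWorkers(seed)

    train_dl = DataLoader(
        ToyDataset(),
        batch_size=2,
        num_workers=2,
        worker_init_fn=worker_init_fn,
    )
    for batch in train_dl:
        print(batch)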
Changes to the file defining AsrDataModule (the same _SeedWorkers class is added and the same local worker_init_fn is replaced):

@@ -41,6 +41,14 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
 
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
 class AsrDataModule:
     def __init__(self, args: argparse.Namespace):
         self.args = args

@@ -259,9 +267,7 @@ class AsrDataModule:
         # 'seed' is derived from the current random state, which will have
         # previously been set in the main process.
         seed = torch.randint(0, 100000, ()).item()
-
-        def worker_init_fn(worker_id: int):
-            fix_random_seed(seed + worker_id)
+        worker_init_fn = _SeedWorkers(seed)
 
         train_dl = DataLoader(
             train,
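fix_random_seed is imported from lhotse in these data modules (the import sits outside the hunks shown); roughly, it seeds the common RNGs from one value, so seed + worker_id gives each dataloader worker a distinct but reproducible random state. A rough approximation of that behavior (an assumption about what fix_random_seed does, not a copy of its implementation):

import random

import numpy as np
import torch


def fix_random_seed_approx(seed: int) -> None:
    # Assumed behavior: seed Python's, NumPy's and PyTorch's global RNGs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


base_seed = 12345
for worker_id in range(4):
    fix_random_seed_approx(base_seed + worker_id)
    print(worker_id, random.random())  # distinct per worker id, reproducible across runs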