mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 07:04:18 +00:00
Fixes after comment.
This commit is contained in:
parent
9e5ca821bc
commit
4f87802200
@ -42,11 +42,11 @@ from icefall.utils import str2bool
|
|||||||
|
|
||||||
|
|
||||||
class _SeedWorkers:
|
class _SeedWorkers:
|
||||||
|
def __init__(self, seed: int):
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
def __call__(self, worker_id: int):
|
def __call__(self, worker_id: int):
|
||||||
# 'seed' is derived from the current random state, which will have
|
fix_random_seed(self.seed + worker_id)
|
||||||
# previously been set in the main process.
|
|
||||||
seed = torch.randint(0, 100000, ()).item()
|
|
||||||
fix_random_seed(seed + worker_id)
|
|
||||||
|
|
||||||
|
|
||||||
class LibriSpeechAsrDataModule:
|
class LibriSpeechAsrDataModule:
|
||||||
@ -311,13 +311,18 @@ class LibriSpeechAsrDataModule:
|
|||||||
logging.info("Loading sampler state dict")
|
logging.info("Loading sampler state dict")
|
||||||
train_sampler.load_state_dict(sampler_state_dict)
|
train_sampler.load_state_dict(sampler_state_dict)
|
||||||
|
|
||||||
|
# 'seed' is derived from the current random state, which will have
|
||||||
|
# previously been set in the main process.
|
||||||
|
seed = torch.randint(0, 100000, ()).item()
|
||||||
|
worker_init_fn = _SeedWorkers(seed)
|
||||||
|
|
||||||
train_dl = DataLoader(
|
train_dl = DataLoader(
|
||||||
train,
|
train,
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
num_workers=self.args.num_workers,
|
num_workers=self.args.num_workers,
|
||||||
persistent_workers=False,
|
persistent_workers=False,
|
||||||
worker_init_fn=_SeedWorkers(),
|
worker_init_fn=worker_init_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
return train_dl
|
return train_dl
|
||||||
|
@ -42,11 +42,11 @@ from icefall.utils import str2bool
|
|||||||
|
|
||||||
|
|
||||||
class _SeedWorkers:
|
class _SeedWorkers:
|
||||||
|
def __init__(self, seed: int):
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
def __call__(self, worker_id: int):
|
def __call__(self, worker_id: int):
|
||||||
# 'seed' is derived from the current random state, which will have
|
fix_random_seed(self.seed + worker_id)
|
||||||
# previously been set in the main process.
|
|
||||||
seed = torch.randint(0, 100000, ()).item()
|
|
||||||
fix_random_seed(seed + worker_id)
|
|
||||||
|
|
||||||
|
|
||||||
class AsrDataModule:
|
class AsrDataModule:
|
||||||
@ -264,13 +264,18 @@ class AsrDataModule:
|
|||||||
|
|
||||||
logging.info("About to create train dataloader")
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
|
# 'seed' is derived from the current random state, which will have
|
||||||
|
# previously been set in the main process.
|
||||||
|
seed = torch.randint(0, 100000, ()).item()
|
||||||
|
worker_init_fn = _SeedWorkers(seed)
|
||||||
|
|
||||||
train_dl = DataLoader(
|
train_dl = DataLoader(
|
||||||
train,
|
train,
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
num_workers=self.args.num_workers,
|
num_workers=self.args.num_workers,
|
||||||
persistent_workers=False,
|
persistent_workers=False,
|
||||||
worker_init_fn=_SeedWorkers(),
|
worker_init_fn=worker_init_fn,
|
||||||
)
|
)
|
||||||
return train_dl
|
return train_dl
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user