Add random-number-setting function in dataloader

This commit is contained in:
Daniel Povey 2022-03-26 14:46:27 +08:00
parent 0e694739f2
commit 26a1730392

View File

@ -22,6 +22,8 @@ import logging
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import torch
import lhotse
from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import ( from lhotse.dataset import (
@ -301,12 +303,19 @@ class LibriSpeechAsrDataModule:
logging.info("Loading sampler state dict") logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict) train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have previously been
# set in the main process.
seed = torch.randint(0, 100000, ()).item()
def worker_init_fn(worker_id: int):
lhotse.utils.fix_random_seed(seed + worker_id)
train_dl = DataLoader( train_dl = DataLoader(
train, train,
sampler=train_sampler, sampler=train_sampler,
batch_size=None, batch_size=None,
num_workers=self.args.num_workers, num_workers=self.args.num_workers,
persistent_workers=False, persistent_workers=False,
worker_init_fn=worker_init_fn,
) )
return train_dl return train_dl