Merge branch 'rework2h_randloader' into rework2h_pow0.333

This commit is contained in:
Daniel Povey 2022-03-29 19:05:39 +08:00
commit 57f943b25c

View File

@ -22,6 +22,8 @@ import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse.utils import fix_random_seed
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
@ -301,12 +303,19 @@ class LibriSpeechAsrDataModule:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have previously been
# set in the main process.
seed = torch.randint(0, 100000, ()).item()
def worker_init_fn(worker_id: int):
fix_random_seed(seed + worker_id)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl