diff --git a/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py b/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py
index 3995f21ee..581ef82a4 100644
--- a/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py
+++ b/egs/multi_ja_en/ASR/local/utils/asr_datamodule.py
@@ -38,6 +38,14 @@ from torch.utils.data import DataLoader
+from lhotse.utils import fix_random_seed
 from icefall.utils import str2bool
 
 
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
 
 class MultiDatasetAsrDataModule:
     """
@@ -192,6 +200,26 @@ class MultiDatasetAsrDataModule:
         transforms = []
         input_transforms = []
 
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            logging.info("About to get Musan cuts")
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            transforms.append(
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                f"Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap)
+            ] + transforms
+
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
             logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
@@ -245,10 +273,12 @@ class MultiDatasetAsrDataModule:
         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
             train_sampler = DynamicBucketingSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
+                buffer_size=self.args.num_buckets * 2000,
+                shuffle_buffer_size=self.args.num_buckets * 5000,
                 drop_last=self.args.drop_last,
             )
         else:
@@ -265,12 +295,16 @@ class MultiDatasetAsrDataModule:
             logging.info("Loading sampler state dict")
             train_sampler.load_state_dict(sampler_state_dict)
 
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=False,
+            persistent_workers=True,
+            worker_init_fn=worker_init_fn,
         )
 
         return train_dl
@@ -332,24 +366,3 @@ class MultiDatasetAsrDataModule:
             num_workers=self.args.num_workers,
         )
         return test_dl
-
-    # @lru_cache()
-    # def train_cuts(self) -> CutSet:
-    #     logging.info("About to get train cuts")
-    #     return load_manifest_lazy(
-    #         self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
-    #     )
-
-    # @lru_cache()
-    # def valid_cuts(self) -> CutSet:
-    #     logging.info("About to get dev cuts")
-    #     return load_manifest_lazy(
-    #         self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
-    #     )
-
-    # @lru_cache()
-    # def test_cuts(self) -> List[CutSet]:
-    #     logging.info("About to get test cuts")
-    #     return load_manifest_lazy(
-    #         self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
-    #     )
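
For reviewers unfamiliar with the `worker_init_fn` pattern this patch adopts, below is a minimal, self-contained sketch of what `_SeedWorkers` accomplishes: each DataLoader worker process re-seeds its RNGs with base seed + worker_id at start-up, so workers draw different random augmentations instead of cloning the parent's RNG state. The stand-in `fix_random_seed` and the `ToyDataset` here are illustrative only (the patch uses lhotse's `fix_random_seed` and a lhotse speech dataset); the `_SeedWorkers` class and the per-run base seed mirror the diff.

import random

import torch
from torch.utils.data import DataLoader, Dataset


def fix_random_seed(seed: int) -> None:
    # Stand-in for lhotse.utils.fix_random_seed: seed the RNGs that
    # augmentation code might use inside this worker process.
    random.seed(seed)
    torch.manual_seed(seed)


class _SeedWorkers:
    # A top-level class rather than a lambda, so the callable is
    # picklable -- required when workers use the "spawn" start method.
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class ToyDataset(Dataset):
    # Hypothetical dataset standing in for the real speech dataset.
    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Any randomness here (e.g. augmentation) now differs per worker,
        # because each worker was re-seeded with seed + worker_id.
        return torch.rand(3)


if __name__ == "__main__":
    # Fresh base seed per run, as in the patch; workers get seed + 0, seed + 1, ...
    seed = torch.randint(0, 100000, ()).item()
    dl = DataLoader(
        ToyDataset(),
        batch_size=4,
        num_workers=2,
        worker_init_fn=_SeedWorkers(seed),
        persistent_workers=True,  # workers survive across epochs
    )
    for batch in dl:
        print(batch.shape)

Note that with `persistent_workers=True` the workers are created once and reused across epochs, so the one-time re-seeding in `worker_init_fn` holds for the lifetime of the loader.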