Changes to asr_datamodule for MUSAN support

Bailey Hirota 2025-07-01 18:18:25 +09:00
parent 252e5eb2e1
commit d8cb41f4f6


@@ -38,6 +38,12 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
 
 class MultiDatasetAsrDataModule:
     """
@@ -192,6 +198,22 @@ class MultiDatasetAsrDataModule:
         transforms = []
         input_transforms = []
 
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            logging.info("About to get Musan cuts")
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            transforms.append(CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True))
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                f"Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            transforms = [
+                CutConcatenate(duration_factor=self.args.duration_factor, gap=self.args.gap)
+            ] + transforms
 
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
@@ -245,11 +267,13 @@ class MultiDatasetAsrDataModule:
         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
             train_sampler = DynamicBucketingSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
+                buffer_size=self.args.num_buckets * 2000,
+                shuffle_buffer_size=self.args.num_buckets * 5000,
                 drop_last=self.args.drop_last,
             )
         else:
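On the two new sampler parameters: in lhotse's `DynamicBucketingSampler`, `buffer_size` is roughly how many cuts are held in memory while being assigned to duration buckets, and `shuffle_buffer_size` bounds the pool used for approximate streaming shuffling, so tying both to `num_buckets` trades memory for better-filled buckets and a better shuffle. A standalone sketch, with an assumed manifest path and illustrative values:

from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler

# Assumed manifest path; any CutSet works here.
cuts = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")

num_buckets = 30
sampler = DynamicBucketingSampler(
    cuts,
    max_duration=600.0,  # total seconds of audio per mini-batch
    shuffle=True,
    num_buckets=num_buckets,
    buffer_size=num_buckets * 2000,          # cuts buffered for bucket assignment
    shuffle_buffer_size=num_buckets * 5000,  # pool for approximate shuffling
    drop_last=True,
)

for batch_cuts in sampler:  # each item is a CutSet forming one batch
    print(len(batch_cuts))
    break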
@@ -265,12 +289,16 @@ class MultiDatasetAsrDataModule:
             logging.info("Loading sampler state dict")
             train_sampler.load_state_dict(sampler_state_dict)
 
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=False,
+            persistent_workers=True,
+            worker_init_fn=worker_init_fn,
         )
 
         return train_dl
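Two notes on the loader changes. The base seed is redrawn with `torch.randint` each time the data module builds a train loader, so separate runs seed their workers differently while a single run stays reproducible. And because `persistent_workers=True` keeps worker processes alive across epochs, `worker_init_fn` (and hence the reseeding) fires once per worker for the loader's lifetime rather than at every epoch boundary. A toy check of that PyTorch behavior (not part of the commit):

from torch.utils.data import DataLoader, Dataset


class _Toy(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return idx


def _init(worker_id: int) -> None:
    print(f"worker {worker_id} initialized")


if __name__ == "__main__":
    dl = DataLoader(_Toy(), num_workers=2, persistent_workers=True, worker_init_fn=_init)
    for epoch in range(3):
        for _ in dl:
            pass
    # Two init messages in total; with persistent_workers=False there would be
    # two per epoch, since workers are torn down and recreated between epochs.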
@@ -332,24 +360,3 @@ class MultiDatasetAsrDataModule:
             num_workers=self.args.num_workers,
         )
         return test_dl
-
-    # @lru_cache()
-    # def train_cuts(self) -> CutSet:
-    #     logging.info("About to get train cuts")
-    #     return load_manifest_lazy(
-    #         self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
-    #     )
-
-    # @lru_cache()
-    # def valid_cuts(self) -> CutSet:
-    #     logging.info("About to get dev cuts")
-    #     return load_manifest_lazy(
-    #         self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
-    #     )
-
-    # @lru_cache()
-    # def test_cuts(self) -> List[CutSet]:
-    #     logging.info("About to get test cuts")
-    #     return load_manifest_lazy(
-    #         self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
-    #     )