Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 09:32:20 +00:00

commit d8cb41f4f6
parent 252e5eb2e1

    changes to asr_datamodule for musan support
@@ -38,6 +38,12 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool
 
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
 
 
 class MultiDatasetAsrDataModule:
     """
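The `_SeedWorkers` class added above is a picklable callable, which is what `DataLoader` requires of `worker_init_fn` when workers run as separate processes (a lambda cannot be pickled under the spawn start method). A standalone toy sketch of the pattern, using `torch.manual_seed` as a stand-in for lhotse's `fix_random_seed` and a hypothetical `_Noise` dataset:

import torch
from torch.utils.data import DataLoader, Dataset


class _SeedWorkers:
    # Picklable callable: worker i is seeded with seed + i.
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        # Stand-in for lhotse.utils.fix_random_seed(self.seed + worker_id),
        # which seeds Python, numpy, and torch RNGs in one call.
        torch.manual_seed(self.seed + worker_id)


class _Noise(Dataset):
    # Hypothetical toy dataset whose items depend on the worker's RNG.
    def __len__(self):
        return 4

    def __getitem__(self, i):
        return torch.rand(1)


if __name__ == "__main__":
    seed = torch.randint(0, 100000, ()).item()
    dl = DataLoader(_Noise(), num_workers=2, worker_init_fn=_SeedWorkers(seed))
    for batch in dl:
        print(batch)  # items differ per worker thanks to distinct seeds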
@@ -192,7 +198,23 @@ class MultiDatasetAsrDataModule:
 
         transforms = []
         input_transforms = []
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            logging.info("About to get Musan cuts")
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            transforms.append(
+                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
 
         if self.args.concatenate_cuts:
             logging.info(
                 f"Using cut concatenation with duration factor "
                 f"{self.args.duration_factor} and gap {self.args.gap}."
             )
             transforms = [
                 CutConcatenate(
                     duration_factor=self.args.duration_factor, gap=self.args.gap
                 )
             ] + transforms
 
         if self.args.enable_spec_aug:
             logging.info("Enable SpecAugment")
             logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
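For context on the new MUSAN branch: lhotse's `CutMix` mixes noise cuts into a fraction of the training cuts on the fly. A minimal sketch of how the transform feeds a dataset, assuming the recipe's standard `musan_cuts.jsonl.gz` manifest exists (the `data/fbank` path is illustrative; depending on the lhotse version, the probability argument is named `p` or `prob`):

from lhotse import load_manifest
from lhotse.dataset import CutMix, K2SpeechRecognitionDataset

# MUSAN noise cuts prepared by the recipe's data-prep stage (path assumed).
cuts_musan = load_manifest("data/fbank/musan_cuts.jsonl.gz")

transforms = [
    # Mix noise into ~50% of cuts at an SNR drawn uniformly from 10-20 dB;
    # preserve_id keeps the original cut IDs on the mixed cuts.
    CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
]

train = K2SpeechRecognitionDataset(cut_transforms=transforms)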
@@ -245,11 +267,13 @@ class MultiDatasetAsrDataModule:
 
         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
-            train_sampler = DynamicBucketingSampler(
+            train_sampler = DynamicBucketingSampler(  # added several new params here
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
+                buffer_size=self.args.num_buckets * 2000,
+                shuffle_buffer_size=self.args.num_buckets * 5000,
                 drop_last=self.args.drop_last,
             )
         else:
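The two new sampler arguments trade RAM for sampling quality: `buffer_size` is how many cuts `DynamicBucketingSampler` holds to assign cuts to duration buckets, and `shuffle_buffer_size` is the window for its streaming shuffle. Scaling both with the bucket count, as the commit does, keeps each bucket reasonably populated. A hedged sketch with illustrative values:

from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler

# Path and numbers are illustrative, not from the commit.
cuts_train = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")

num_buckets = 30
train_sampler = DynamicBucketingSampler(
    cuts_train,
    max_duration=200.0,  # total seconds of audio per batch
    shuffle=True,
    num_buckets=num_buckets,
    buffer_size=num_buckets * 2000,          # 60k cuts for bucket estimation
    shuffle_buffer_size=num_buckets * 5000,  # 150k-cut shuffle window
    drop_last=True,
)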
@@ -265,12 +289,16 @@ class MultiDatasetAsrDataModule:
             logging.info("Loading sampler state dict")
             train_sampler.load_state_dict(sampler_state_dict)
 
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
-            persistent_workers=False,
+            persistent_workers=True,
+            worker_init_fn=worker_init_fn,  # changed bottom 2 params
         )
 
         return train_dl
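The two changed `DataLoader` parameters interact: with `persistent_workers=True`, worker processes outlive an epoch, so `worker_init_fn` runs once per worker for the whole run rather than once per epoch. A self-contained toy demonstration (the dataset and seed values are illustrative):

import torch
from torch.utils.data import DataLoader, Dataset


class _Tiny(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, i):
        return torch.randn(2)


def seed_worker(worker_id: int) -> None:
    # With persistent_workers=True this prints once per worker, not once
    # per epoch, because the same processes serve every epoch.
    print(f"init worker {worker_id}")
    torch.manual_seed(1234 + worker_id)


if __name__ == "__main__":
    dl = DataLoader(
        _Tiny(),
        batch_size=2,
        num_workers=2,
        persistent_workers=True,
        worker_init_fn=seed_worker,
    )
    for epoch in range(2):
        for _ in dl:  # the second epoch reuses the already-seeded workers
            pass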
@@ -332,24 +360,3 @@ class MultiDatasetAsrDataModule:
             num_workers=self.args.num_workers,
         )
         return test_dl
-
-    # @lru_cache()
-    # def train_cuts(self) -> CutSet:
-    #     logging.info("About to get train cuts")
-    #     return load_manifest_lazy(
-    #         self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
-    #     )
-
-    # @lru_cache()
-    # def valid_cuts(self) -> CutSet:
-    #     logging.info("About to get dev cuts")
-    #     return load_manifest_lazy(
-    #         self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
-    #     )
-
-    # @lru_cache()
-    # def test_cuts(self) -> List[CutSet]:
-    #     logging.info("About to get test cuts")
-    #     return load_manifest_lazy(
-    #         self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
-    #     )
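The lines deleted in this last hunk were commented-out ReazonSpeech cut accessors. For reference, the pattern they follow (stream a manifest lazily and memoize the result) looks roughly like this as a free function; the signature is adapted for a standalone sketch, since the real methods live on the data module:

from functools import lru_cache
from pathlib import Path

from lhotse import CutSet, load_manifest_lazy


@lru_cache()
def train_cuts(manifest_dir: Path = Path("data/manifests")) -> CutSet:
    # load_manifest_lazy streams cuts from the gzipped JSONL manifest rather
    # than materializing the whole CutSet; lru_cache memoizes the handle so
    # repeated calls return the same object.
    return load_manifest_lazy(manifest_dir / "reazonspeech_cuts_train.jsonl.gz")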