mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
changes to asr_datamodule for musan support
This commit is contained in:
parent
252e5eb2e1
commit
d8cb41f4f6
@ -38,6 +38,12 @@ from torch.utils.data import DataLoader
|
|||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
class _SeedWorkers:
|
||||||
|
def __init__(self, seed: int):
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def __call__(self, worker_id: int):
|
||||||
|
fix_random_seed(self.seed + worker_id)
|
||||||
|
|
||||||
class MultiDatasetAsrDataModule:
|
class MultiDatasetAsrDataModule:
|
||||||
"""
|
"""
|
||||||
@ -192,6 +198,22 @@ class MultiDatasetAsrDataModule:
|
|||||||
|
|
||||||
transforms = []
|
transforms = []
|
||||||
input_transforms = []
|
input_transforms = []
|
||||||
|
if self.args.enable_musan:
|
||||||
|
logging.info("Enable MUSAN")
|
||||||
|
logging.info("About to get Musan cuts")
|
||||||
|
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||||
|
transforms.append(CutMix(cuts=cuts_musan, p=0.5, snr=(10,20), preserve_id=True)
|
||||||
|
else:
|
||||||
|
logging.info("Disable MUSAN")
|
||||||
|
|
||||||
|
if self.args.concatenate_cuts:
|
||||||
|
logging.info(
|
||||||
|
f"Using cut concatenation with duration factor "
|
||||||
|
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||||
|
transforms = [
|
||||||
|
CutConcatenate(
|
||||||
|
duration_factor=self.args.duration_factor, gap=self.args.gap)
|
||||||
|
] + transforms
|
||||||
|
|
||||||
if self.args.enable_spec_aug:
|
if self.args.enable_spec_aug:
|
||||||
logging.info("Enable SpecAugment")
|
logging.info("Enable SpecAugment")
|
||||||
@ -245,11 +267,13 @@ class MultiDatasetAsrDataModule:
|
|||||||
|
|
||||||
if self.args.bucketing_sampler:
|
if self.args.bucketing_sampler:
|
||||||
logging.info("Using DynamicBucketingSampler.")
|
logging.info("Using DynamicBucketingSampler.")
|
||||||
train_sampler = DynamicBucketingSampler(
|
train_sampler = DynamicBucketingSampler( #added several new params here
|
||||||
cuts_train,
|
cuts_train,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=self.args.shuffle,
|
shuffle=self.args.shuffle,
|
||||||
num_buckets=self.args.num_buckets,
|
num_buckets=self.args.num_buckets,
|
||||||
|
buffer_size=self.args.num_buckers * 2000,
|
||||||
|
shuffle_buffer_size=self.args.num_buckers * 5000,
|
||||||
drop_last=self.args.drop_last,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -265,12 +289,16 @@ class MultiDatasetAsrDataModule:
|
|||||||
logging.info("Loading sampler state dict")
|
logging.info("Loading sampler state dict")
|
||||||
train_sampler.load_state_dict(sampler_state_dict)
|
train_sampler.load_state_dict(sampler_state_dict)
|
||||||
|
|
||||||
|
seed = torch.randint(0, 100000, ()).item()
|
||||||
|
worker_init_fn = _SeedWorkers(seed)
|
||||||
|
|
||||||
train_dl = DataLoader(
|
train_dl = DataLoader(
|
||||||
train,
|
train,
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
batch_size=None,
|
batch_size=None,
|
||||||
num_workers=self.args.num_workers,
|
num_workers=self.args.num_workers,
|
||||||
persistent_workers=False,
|
persistent_workers=True,
|
||||||
|
worker_init_fn=worker_init_fn, #changed bottom 2 params
|
||||||
)
|
)
|
||||||
|
|
||||||
return train_dl
|
return train_dl
|
||||||
@ -332,24 +360,3 @@ class MultiDatasetAsrDataModule:
|
|||||||
num_workers=self.args.num_workers,
|
num_workers=self.args.num_workers,
|
||||||
)
|
)
|
||||||
return test_dl
|
return test_dl
|
||||||
|
|
||||||
# @lru_cache()
|
|
||||||
# def train_cuts(self) -> CutSet:
|
|
||||||
# logging.info("About to get train cuts")
|
|
||||||
# return load_manifest_lazy(
|
|
||||||
# self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
|
|
||||||
# )
|
|
||||||
|
|
||||||
# @lru_cache()
|
|
||||||
# def valid_cuts(self) -> CutSet:
|
|
||||||
# logging.info("About to get dev cuts")
|
|
||||||
# return load_manifest_lazy(
|
|
||||||
# self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
|
|
||||||
# )
|
|
||||||
|
|
||||||
# @lru_cache()
|
|
||||||
# def test_cuts(self) -> List[CutSet]:
|
|
||||||
# logging.info("About to get test cuts")
|
|
||||||
# return load_manifest_lazy(
|
|
||||||
# self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
|
|
||||||
# )
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user