add args in asr datamodule

This commit is contained in:
marcoyang 2024-01-05 16:12:24 +08:00
parent 71d530db37
commit 6665469a4e

View File

@ -102,6 +102,20 @@ class SPGISpeechAsrDataModule:
help="Determines the maximum duration of a concatenated cut " help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.", "relative to the duration of the longest cut in a batch.",
) )
group.add_argument(
"--drop-last",
type=str2bool,
default=False,
help="When enabled, the last batch will be dropped",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument( group.add_argument(
"--gap", "--gap",
type=float, type=float,
@ -143,7 +157,7 @@ class SPGISpeechAsrDataModule:
group.add_argument( group.add_argument(
"--num-workers", "--num-workers",
type=int, type=int,
default=8, default=2,
help="The number of training dataloader workers that " help="The number of training dataloader workers that "
"collect the batches.", "collect the batches.",
) )
@ -176,7 +190,7 @@ class SPGISpeechAsrDataModule:
The state dict for the training sampler. The state dict for the training sampler.
""" """
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = [] transforms = []
if self.args.enable_musan: if self.args.enable_musan:
@ -223,11 +237,13 @@ class SPGISpeechAsrDataModule:
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
) )
else: else:
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
) )
logging.info("Using DynamicBucketingSampler.") logging.info("Using DynamicBucketingSampler.")
@ -236,7 +252,7 @@ class SPGISpeechAsrDataModule:
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
shuffle=False, shuffle=False,
num_buckets=self.args.num_buckets, num_buckets=self.args.num_buckets,
drop_last=True, drop_last=self.args.drop_last,
) )
logging.info("About to create train dataloader") logging.info("About to create train dataloader")
@ -274,10 +290,12 @@ class SPGISpeechAsrDataModule:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
) )
else: else:
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
return_cuts=self.args.return_cuts,
) )
valid_sampler = DynamicBucketingSampler( valid_sampler = DynamicBucketingSampler(
cuts_valid, cuts_valid,
@ -301,6 +319,7 @@ class SPGISpeechAsrDataModule:
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats if self.args.on_the_fly_feats
else PrecomputedFeatures(), else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
) )
sampler = DynamicBucketingSampler( sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False cuts, max_duration=self.args.max_duration, shuffle=False