add args in asr datamodule

This commit is contained in:
marcoyang 2024-01-05 16:12:24 +08:00
parent 71d530db37
commit 6665469a4e

View File

@ -102,6 +102,20 @@ class SPGISpeechAsrDataModule:
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=False,
help="When enabled, the last batch will be dropped",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--gap",
type=float,
@ -143,7 +157,7 @@ class SPGISpeechAsrDataModule:
group.add_argument(
"--num-workers",
type=int,
default=8,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
@ -176,7 +190,7 @@ class SPGISpeechAsrDataModule:
The state dict for the training sampler.
"""
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms = []
if self.args.enable_musan:
@ -223,11 +237,13 @@ class SPGISpeechAsrDataModule:
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
else:
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
logging.info("Using DynamicBucketingSampler.")
@ -236,7 +252,7 @@ class SPGISpeechAsrDataModule:
max_duration=self.args.max_duration,
shuffle=False,
num_buckets=self.args.num_buckets,
drop_last=True,
drop_last=self.args.drop_last,
)
logging.info("About to create train dataloader")
@ -274,10 +290,12 @@ class SPGISpeechAsrDataModule:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
@ -301,6 +319,7 @@ class SPGISpeechAsrDataModule:
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts, max_duration=self.args.max_duration, shuffle=False