mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
add args in asr datamodule
This commit is contained in:
parent
71d530db37
commit
6665469a4e
@ -102,6 +102,20 @@ class SPGISpeechAsrDataModule:
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, the last batch will be dropped",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
@ -143,7 +157,7 @@ class SPGISpeechAsrDataModule:
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=8,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
@ -176,7 +190,7 @@ class SPGISpeechAsrDataModule:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
@ -223,11 +237,13 @@ class SPGISpeechAsrDataModule:
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
@ -236,7 +252,7 @@ class SPGISpeechAsrDataModule:
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=True,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
@ -274,10 +290,12 @@ class SPGISpeechAsrDataModule:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
@ -301,6 +319,7 @@ class SPGISpeechAsrDataModule:
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts, max_duration=self.args.max_duration, shuffle=False
|
||||
|
Loading…
x
Reference in New Issue
Block a user