mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 15:14:18 +00:00
add args in asr datamodule
This commit is contained in:
parent
71d530db37
commit
6665469a4e
@ -102,6 +102,20 @@ class SPGISpeechAsrDataModule:
|
|||||||
help="Determines the maximum duration of a concatenated cut "
|
help="Determines the maximum duration of a concatenated cut "
|
||||||
"relative to the duration of the longest cut in a batch.",
|
"relative to the duration of the longest cut in a batch.",
|
||||||
)
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--drop-last",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="When enabled, the last batch will be dropped",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--return-cuts",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, each batch will have the "
|
||||||
|
"field: batch['supervisions']['cut'] with the cuts that "
|
||||||
|
"were used to construct it.",
|
||||||
|
)
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--gap",
|
"--gap",
|
||||||
type=float,
|
type=float,
|
||||||
@ -143,7 +157,7 @@ class SPGISpeechAsrDataModule:
|
|||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--num-workers",
|
"--num-workers",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=2,
|
||||||
help="The number of training dataloader workers that "
|
help="The number of training dataloader workers that "
|
||||||
"collect the batches.",
|
"collect the batches.",
|
||||||
)
|
)
|
||||||
@ -176,7 +190,7 @@ class SPGISpeechAsrDataModule:
|
|||||||
The state dict for the training sampler.
|
The state dict for the training sampler.
|
||||||
"""
|
"""
|
||||||
logging.info("About to get Musan cuts")
|
logging.info("About to get Musan cuts")
|
||||||
cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz")
|
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||||
|
|
||||||
transforms = []
|
transforms = []
|
||||||
if self.args.enable_musan:
|
if self.args.enable_musan:
|
||||||
@ -223,11 +237,13 @@ class SPGISpeechAsrDataModule:
|
|||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train = K2SpeechRecognitionDataset(
|
train = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Using DynamicBucketingSampler.")
|
logging.info("Using DynamicBucketingSampler.")
|
||||||
@ -236,7 +252,7 @@ class SPGISpeechAsrDataModule:
|
|||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_buckets=self.args.num_buckets,
|
num_buckets=self.args.num_buckets,
|
||||||
drop_last=True,
|
drop_last=self.args.drop_last,
|
||||||
)
|
)
|
||||||
logging.info("About to create train dataloader")
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
@ -274,10 +290,12 @@ class SPGISpeechAsrDataModule:
|
|||||||
validate = K2SpeechRecognitionDataset(
|
validate = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
validate = K2SpeechRecognitionDataset(
|
validate = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
valid_sampler = DynamicBucketingSampler(
|
valid_sampler = DynamicBucketingSampler(
|
||||||
cuts_valid,
|
cuts_valid,
|
||||||
@ -301,6 +319,7 @@ class SPGISpeechAsrDataModule:
|
|||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||||
if self.args.on_the_fly_feats
|
if self.args.on_the_fly_feats
|
||||||
else PrecomputedFeatures(),
|
else PrecomputedFeatures(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = DynamicBucketingSampler(
|
sampler = DynamicBucketingSampler(
|
||||||
cuts, max_duration=self.args.max_duration, shuffle=False
|
cuts, max_duration=self.args.max_duration, shuffle=False
|
||||||
|
Loading…
x
Reference in New Issue
Block a user