From ff53cb045853fd045f658749bc6552c95c11796e Mon Sep 17 00:00:00 2001 From: jinzr Date: Fri, 8 Mar 2024 02:10:57 +0800 Subject: [PATCH] minor fixes --- .../pruned_transducer_stateless7/asr_datamodule.py | 8 +++++--- .../commonvoice_fr.py | 14 ++++++++------ .../ASR/zipformer_prompt_asr/asr_datamodule.py | 8 +++++--- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py index c40d9419b..41009831c 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -381,9 +381,11 @@ class CommonVoiceAsrDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)() + ), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py index 79cf86b84..da8e62034 100644 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/commonvoice_fr.py @@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, - SingleCutSampler, + SimpleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -315,8 +315,8 @@ class CommonVoiceAsrDataModule: drop_last=self.args.drop_last, ) else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, @@ -383,9 +383,11 @@ class CommonVoiceAsrDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)(), + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)() + ), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py index 1a4c9a532..552f63905 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/asr_datamodule.py @@ -425,9 +425,11 @@ class LibriHeavyAsrDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures() + ), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler(