diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 814390ad6..b83a61ccf 100644
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -24,7 +24,15 @@ from pathlib import Path
 from typing import Any, Dict, Optional
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    load_manifest,
+    load_manifest_lazy,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     CutConcatenate,
     CutMix,
@@ -215,6 +223,20 @@ class LibriSpeechAsrDataModule:
             help="AudioSamples or PrecomputedFeatures",
         )
 
+        group.add_argument(
+            "--use-whisper-fbank",
+            type=str2bool,
+            default=False,
+            help="Use whisper fbank feature as input",
+        )
+
+        group.add_argument(
+            "--whisper-fbank-n-mels",
+            type=int,
+            default=80,
+            help="Number of mels for whisper fbank, large-v3 uses 128-mel fbank",
+        )
+
     def train_dataloaders(
         self,
         cuts_train: CutSet,
@@ -297,9 +319,15 @@ class LibriSpeechAsrDataModule:
             # to be strict (e.g. could be randomized)
             # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
             # Drop feats to be on the safe side.
+            if self.args.use_whisper_fbank:
+                extractor = WhisperFbank(
+                    WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
+                )
+            else:
+                extractor = Fbank(FbankConfig(num_mel_bins=80))
             train = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(extractor),
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
@@ -355,9 +383,15 @@ class LibriSpeechAsrDataModule:
 
         logging.info("About to create dev dataset")
         if self.args.on_the_fly_feats:
+            if self.args.use_whisper_fbank:
+                extractor = WhisperFbank(
+                    WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
+                )
+            else:
+                extractor = Fbank(FbankConfig(num_mel_bins=80))
             validate = K2SpeechRecognitionDataset(
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                input_strategy=OnTheFlyFeatures(extractor),
                 return_cuts=self.args.return_cuts,
             )
         else:
@@ -383,8 +417,15 @@ class LibriSpeechAsrDataModule:
 
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
+        if self.args.use_whisper_fbank:
+            extractor = WhisperFbank(
+                WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
+            )
+        else:
+            extractor = Fbank(FbankConfig(num_mel_bins=80))
+
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            input_strategy=OnTheFlyFeatures(extractor)
             if self.args.on_the_fly_feats
             else eval(self.args.input_strategy)(),
             return_cuts=self.args.return_cuts,