Support on-the-fly Whisper fbank extraction

This commit is contained in:
marcoyang 2024-03-29 11:03:58 +08:00
parent 4d9f2120b3
commit f208431f5c

View File

@@ -24,7 +24,15 @@ from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import torch import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy from lhotse import (
CutSet,
Fbank,
FbankConfig,
load_manifest,
load_manifest_lazy,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate, CutConcatenate,
CutMix, CutMix,
@@ -215,6 +223,20 @@ class LibriSpeechAsrDataModule:
help="AudioSamples or PrecomputedFeatures", help="AudioSamples or PrecomputedFeatures",
) )
group.add_argument(
"--use-whisper-fbank",
type=str2bool,
default=False,
help="Use whisper fbank feature as input",
)
group.add_argument(
"--whisper-fbank-n-mels",
type=int,
default=80,
help="Number of mels for whisper fbank, large-v3 uses 128-mel fbank",
)
def train_dataloaders( def train_dataloaders(
self, self,
cuts_train: CutSet, cuts_train: CutSet,
@@ -297,9 +319,15 @@ class LibriSpeechAsrDataModule:
# to be strict (e.g. could be randomized) # to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side. # Drop feats to be on the safe side.
if self.args.use_whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=80))
train = K2SpeechRecognitionDataset( train = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(extractor),
input_transforms=input_transforms, input_transforms=input_transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
@@ -355,9 +383,15 @@ class LibriSpeechAsrDataModule:
logging.info("About to create dev dataset") logging.info("About to create dev dataset")
if self.args.on_the_fly_feats: if self.args.on_the_fly_feats:
if self.args.use_whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=80))
validate = K2SpeechRecognitionDataset( validate = K2SpeechRecognitionDataset(
cut_transforms=transforms, cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), input_strategy=OnTheFlyFeatures(extractor),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
else: else:
@@ -383,8 +417,15 @@ class LibriSpeechAsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset") logging.debug("About to create test dataset")
if self.args.use_whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=80))
test = K2SpeechRecognitionDataset( test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) input_strategy=OnTheFlyFeatures(extractor)
if self.args.on_the_fly_feats if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(), else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,