mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
support on-the-fly whisper fbank extraction
This commit is contained in:
parent
4d9f2120b3
commit
f208431f5c
@ -24,7 +24,15 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
load_manifest,
|
||||
load_manifest_lazy,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
@ -215,6 +223,20 @@ class LibriSpeechAsrDataModule:
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--use-whisper-fbank",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Use whisper fbank feature as input",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--whisper-fbank-n-mels",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Number of mels for whisper fbank, large-v3 uses 128-mel fbank",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
@ -297,9 +319,15 @@ class LibriSpeechAsrDataModule:
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
if self.args.use_whisper_fbank:
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=80))
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_strategy=OnTheFlyFeatures(extractor),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
@ -355,9 +383,15 @@ class LibriSpeechAsrDataModule:
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
if self.args.use_whisper_fbank:
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=80))
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_strategy=OnTheFlyFeatures(extractor),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
@ -383,8 +417,15 @@ class LibriSpeechAsrDataModule:
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
if self.args.use_whisper_fbank:
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=80))
|
||||
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
input_strategy=OnTheFlyFeatures(extractor)
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
|
Loading…
x
Reference in New Issue
Block a user