mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
support on-the-fly whisper fbank extraction
This commit is contained in:
parent
4d9f2120b3
commit
f208431f5c
@ -24,7 +24,15 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
Fbank,
|
||||||
|
FbankConfig,
|
||||||
|
load_manifest,
|
||||||
|
load_manifest_lazy,
|
||||||
|
WhisperFbank,
|
||||||
|
WhisperFbankConfig,
|
||||||
|
)
|
||||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||||
CutConcatenate,
|
CutConcatenate,
|
||||||
CutMix,
|
CutMix,
|
||||||
@ -215,6 +223,20 @@ class LibriSpeechAsrDataModule:
|
|||||||
help="AudioSamples or PrecomputedFeatures",
|
help="AudioSamples or PrecomputedFeatures",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--use-whisper-fbank",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Use whisper fbank feature as input",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--whisper-fbank-n-mels",
|
||||||
|
type=int,
|
||||||
|
default=80,
|
||||||
|
help="Number of mels for whisper fbank, large-v3 uses 128-mel fbank",
|
||||||
|
)
|
||||||
|
|
||||||
def train_dataloaders(
|
def train_dataloaders(
|
||||||
self,
|
self,
|
||||||
cuts_train: CutSet,
|
cuts_train: CutSet,
|
||||||
@ -297,9 +319,15 @@ class LibriSpeechAsrDataModule:
|
|||||||
# to be strict (e.g. could be randomized)
|
# to be strict (e.g. could be randomized)
|
||||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||||
# Drop feats to be on the safe side.
|
# Drop feats to be on the safe side.
|
||||||
|
if self.args.use_whisper_fbank:
|
||||||
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extractor = Fbank(FbankConfig(num_mel_bins=80))
|
||||||
train = K2SpeechRecognitionDataset(
|
train = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
input_strategy=OnTheFlyFeatures(extractor),
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
@ -355,9 +383,15 @@ class LibriSpeechAsrDataModule:
|
|||||||
|
|
||||||
logging.info("About to create dev dataset")
|
logging.info("About to create dev dataset")
|
||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
|
if self.args.use_whisper_fbank:
|
||||||
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extractor = Fbank(FbankConfig(num_mel_bins=80))
|
||||||
validate = K2SpeechRecognitionDataset(
|
validate = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
input_strategy=OnTheFlyFeatures(extractor),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -383,8 +417,15 @@ class LibriSpeechAsrDataModule:
|
|||||||
|
|
||||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
logging.debug("About to create test dataset")
|
logging.debug("About to create test dataset")
|
||||||
|
if self.args.use_whisper_fbank:
|
||||||
|
extractor = WhisperFbank(
|
||||||
|
WhisperFbankConfig(num_filters=self.args.whisper_fbank_n_mels),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
extractor = Fbank(FbankConfig(num_mel_bins=80))
|
||||||
|
|
||||||
test = K2SpeechRecognitionDataset(
|
test = K2SpeechRecognitionDataset(
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
input_strategy=OnTheFlyFeatures(extractor)
|
||||||
if self.args.on_the_fly_feats
|
if self.args.on_the_fly_feats
|
||||||
else eval(self.args.input_strategy)(),
|
else eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user