use KaldifeatFbank extractor

This commit is contained in:
Guo Liyong 2021-11-05 21:25:40 +08:00
parent 83b2705b44
commit a1cdf09655
2 changed files with 18 additions and 7 deletions

View File

@ -9,7 +9,7 @@ from typing import List, Union
from torch.utils.data import DataLoader
from lhotse import CutSet, KaldifeatFbank, FbankConfig, load_manifest
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
@ -261,7 +261,10 @@ class GigaSpeechAsrDataModule(DataModule):
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
KaldifeatFbank(FbankConfig(num_mel_bins=80)),
# To avoid unexpected GPU OOM issue during training,
# I think using the cpu version is safer
# KaldifeatFbank(KaldifeatFbankConfig(device='cuda')),
KaldifeatFbank(KaldifeatFbankConfig()),
num_workers=self.args.giga_num_workers_inner,
),
return_cuts=self.args.giga_return_cuts,
@ -316,7 +319,10 @@ class GigaSpeechAsrDataModule(DataModule):
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
KaldifeatFbank(FbankConfig(num_mel_bins=80)), num_workers=8
# To avoid unexpected GPU OOM issue during training,
# I think using the cpu version is safer
# KaldifeatFbank(KaldifeatFbankConfig(device='cuda')), num_workers=8
KaldifeatFbank(KaldifeatFbankConfig()), num_workers=8
),
return_cuts=self.args.giga_return_cuts,
)
@ -357,7 +363,10 @@ class GigaSpeechAsrDataModule(DataModule):
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=(
OnTheFlyFeatures(KaldifeatFbank(FbankConfig(num_mel_bins=80)), num_workers=8)
# To avoid unexpected GPU OOM issue during training,
# I think using the cpu version is safer
# OnTheFlyFeatures(KaldifeatFbank(KaldifeatFbankConfig(device='cuda')), num_workers=8)
OnTheFlyFeatures(KaldifeatFbank(KaldifeatFbankConfig()), num_workers=8)
if self.args.giga_on_the_fly_feats
else PrecomputedFeatures()
),

View File

@ -15,8 +15,8 @@ import torch
from gigaspeech_datamodule import get_context_suffix
from lhotse import (
CutSet,
Fbank,
FbankConfig,
KaldifeatFbank,
KaldifeatFbankConfig,
LilcomHdf5Writer,
SupervisionSegment,
combine,
@ -183,7 +183,8 @@ def main():
ctx_suffix = get_context_suffix(args, subparser=False)
print("Feature extraction:")
extractor = Fbank(FbankConfig(num_mel_bins=80))
# extractor = Fbank(FbankConfig(num_mel_bins=80))
extractor = KaldifeatFbank(KaldifeatFbankConfig(device='cuda')) # default config uses 80 mel bins already
with get_executor() as ex: # Initialize the executor only once.
for partition, manifests in gigaspeech_manifests.items():
raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
@ -268,6 +269,7 @@ def main():
storage_path=f"{output_dir}/feats_gigaspeech_{partition}",
batch_duration=args.batch_duration,
num_workers=args.num_workers,
storage_type=partial(LilcomHdf5Writer, tick_power=-3),
)