mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)
use KaldifeatFbank extractor
commit a1cdf09655
parent 83b2705b44
@@ -9,7 +9,7 @@ from typing import List, Union
 from torch.utils.data import DataLoader

-from lhotse import CutSet, KaldifeatFbank, FbankConfig, load_manifest
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,
     CutConcatenate,
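
For reference, the classes swapped in here can be used on their own, roughly as in the sketch below. This is illustrative code, not part of the commit; it assumes lhotse's generic FeatureExtractor.extract(samples, sampling_rate) interface and an installed kaldifeat package. KaldifeatFbankConfig defaults to 80 mel bins on CPU, so it is a drop-in replacement for the old Fbank(FbankConfig(num_mel_bins=80)).

import torch
from lhotse import KaldifeatFbank, KaldifeatFbankConfig

# Default config: CPU device, 80 mel bins (same as the old FbankConfig(num_mel_bins=80)).
extractor = KaldifeatFbank(KaldifeatFbankConfig())

samples = torch.randn(1, 16000)  # one second of fake 16 kHz audio, channel-first
feats = extractor.extract(samples, sampling_rate=16000)
print(feats.shape)  # approximately (num_frames, 80)
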
@@ -261,7 +261,10 @@ class GigaSpeechAsrDataModule(DataModule):
         train = K2SpeechRecognitionDataset(
             cut_transforms=transforms,
             input_strategy=OnTheFlyFeatures(
-                KaldifeatFbank(FbankConfig(num_mel_bins=80)),
+                # To avoid unexpected GPU OOM issue during training,
+                # I think using the cpu version is safer
+                # KaldifeatFbank(KaldifeatFbankConfig(device='cuda')),
+                KaldifeatFbank(KaldifeatFbankConfig()),
                 num_workers=self.args.giga_num_workers_inner,
             ),
             return_cuts=self.args.giga_return_cuts,
@@ -316,7 +319,10 @@ class GigaSpeechAsrDataModule(DataModule):
         validate = K2SpeechRecognitionDataset(
             cut_transforms=transforms,
             input_strategy=OnTheFlyFeatures(
-                KaldifeatFbank(FbankConfig(num_mel_bins=80)), num_workers=8
+                # To avoid unexpected GPU OOM issue during training,
+                # I think using the cpu version is safer
+                # KaldifeatFbank(KaldifeatFbankConfig(device='cuda')), num_workers=8
+                KaldifeatFbank(KaldifeatFbankConfig()), num_workers=8
             ),
             return_cuts=self.args.giga_return_cuts,
         )
@@ -357,7 +363,10 @@ class GigaSpeechAsrDataModule(DataModule):
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
             input_strategy=(
-                OnTheFlyFeatures(KaldifeatFbank(FbankConfig(num_mel_bins=80)), num_workers=8)
+                # To avoid unexpected GPU OOM issue during training,
+                # I think using the cpu version is safer
+                # OnTheFlyFeatures(KaldifeatFbank(KaldifeatFbankConfig(device='cuda')), num_workers=8)
+                OnTheFlyFeatures(KaldifeatFbank(KaldifeatFbankConfig()), num_workers=8)
                 if self.args.giga_on_the_fly_feats
                 else PrecomputedFeatures()
             ),
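
All three hunks above follow the same pattern: a default (CPU) KaldifeatFbank wrapped in OnTheFlyFeatures, so the fbank computation happens in the data-loading workers instead of on the GPU, which avoids the OOM risk called out in the comments. Below is a minimal standalone sketch of that pattern; it is not the icefall datamodule itself, the manifest path is hypothetical, and the sampler settings are only illustrative.

from torch.utils.data import DataLoader
from lhotse import KaldifeatFbank, KaldifeatFbankConfig, load_manifest
from lhotse.dataset import BucketingSampler, K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import OnTheFlyFeatures

cuts = load_manifest("data/fbank/cuts_dev.jsonl.gz")  # hypothetical manifest path

# Features are computed on CPU inside the DataLoader workers.
dataset = K2SpeechRecognitionDataset(
    input_strategy=OnTheFlyFeatures(KaldifeatFbank(KaldifeatFbankConfig())),
    return_cuts=True,
)
sampler = BucketingSampler(cuts, max_duration=200.0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)

batch = next(iter(loader))
print(batch["inputs"].shape)  # (N, T, 80) on-the-fly fbank features
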
@@ -15,8 +15,8 @@ import torch
 from gigaspeech_datamodule import get_context_suffix
 from lhotse import (
     CutSet,
-    Fbank,
-    FbankConfig,
+    KaldifeatFbank,
+    KaldifeatFbankConfig,
     LilcomHdf5Writer,
     SupervisionSegment,
     combine,
@@ -183,7 +183,8 @@ def main():
     ctx_suffix = get_context_suffix(args, subparser=False)

     print("Feature extraction:")
-    extractor = Fbank(FbankConfig(num_mel_bins=80))
+    # extractor = Fbank(FbankConfig(num_mel_bins=80))
+    extractor = KaldifeatFbank(KaldifeatFbankConfig(device='cuda'))  # default config uses 80 mel bins already
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, manifests in gigaspeech_manifests.items():
             raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
@@ -268,6 +269,7 @@ def main():
                 storage_path=f"{output_dir}/feats_gigaspeech_{partition}",
                 batch_duration=args.batch_duration,
                 num_workers=args.num_workers,
+                storage_type=partial(LilcomHdf5Writer, tick_power=-3),
             )
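
For completeness, here is a sketch of how the GPU extractor from the previous hunk and the new storage_type argument fit together. The enclosing call is assumed to be lhotse's CutSet.compute_and_store_features_batch (its name is not visible in the hunk), and the paths are hypothetical stand-ins for the script's f-strings.

from functools import partial

from lhotse import KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer, load_manifest

cut_set = load_manifest("data/manifests/gigaspeech_cuts_dev_raw.jsonl.gz")  # hypothetical path
extractor = KaldifeatFbank(KaldifeatFbankConfig(device="cuda"))  # 80 mel bins by default

# Assumed to be the call that surrounds the hunk above: batched GPU feature extraction.
cut_set = cut_set.compute_and_store_features_batch(
    extractor=extractor,
    storage_path="data/fbank/feats_gigaspeech_dev",  # hypothetical output path
    batch_duration=600.0,  # seconds of audio sent to the GPU per batch
    num_workers=4,
    storage_type=partial(LilcomHdf5Writer, tick_power=-3),  # tick_power tunes lilcom's quantization step
)
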