inexhaustible train dataloaders

This commit is contained in:
Guo Liyong 2021-11-03 17:42:33 +08:00
parent bf98c0fd27
commit f679d6063b
3 changed files with 537 additions and 217 deletions

View File

@ -1,273 +1,311 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Johns Hopkins University (Piotr Żelasko)
# Apache 2.0
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import List, Union
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
CutMix,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.dataloading import LhotseDataLoader
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
def get_context_suffix(args):
if args.context_window is None or args.context_window <= 0.0:
ctx_suffix = ""
else:
ctx_suffix = f"_{args.context_direction}{args.context_window}"
return ctx_suffix
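# For illustration only (hypothetical values): the suffix produced here matches
# the manifest names written by the data preparation script added later in this
# commit, e.g. "gigaspeech_cuts_XS_center2.0.jsonl.gz".
#
#   >>> from argparse import Namespace
#   >>> get_context_suffix(Namespace(context_window=2.0, context_direction="center"))
#   '_center2.0'
#   >>> get_context_suffix(Namespace(context_window=0.0, context_direction="center"))
#   ''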
class GigaSpeechAsrDataModule(DataModule):
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and one valid dataloader,
but there can be multiple test dataloaders (e.g. GigaSpeech DEV and TEST).
It contains all the common data pipeline modules used in ASR experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args):
self.total_train_cuts = 0
self.consumed_cuts = 0
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--feature-dir",
type=Path,
default=Path("exp/data"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=500.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=False,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the BucketingSampler "
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=True,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--check-cuts",
type=str2bool,
default=True,
help="When enabled (=default), we will iterate over the whole "
"training cut set to validate it. It should be disabled when "
"using Apache Arrow manifests to avoid an excessive start-up "
"time of the script with datasets larger than 1000 hours.",
)
# GigaSpeech-specific arguments
group.add_argument(
"--subset",
type=str,
default="XS",
help="Select the GigaSpeech subset (XS|S|M|L|XL).",
)
group.add_argument(
"--context-window",
type=float,
default=0.0,
help="Training cut duration in seconds. "
"Use 0 to train on supervision segments without acoustic context, "
"with variable cut lengths; a value larger than zero will create "
"multi-supervision cuts with actual acoustic context.",
)
group.add_argument(
"--context-direction",
type=str,
default="center",
help="If context-window is 0, does nothing. "
"If it's larger than 0, determines in which direction "
"(relative to the supervision) to seek for extra acoustic context. "
"Available values: (left|right|center|random).",
)
group.add_argument(
"--use-context-for-test",
type=str2bool,
default=False,
help="Whether to read dev/test cuts with acoustic context or "
"without it (note: for now, they may contain duplicated segments).",
)
group.add_argument(
"--small-dev",
type=str2bool,
default=False,
help="Whether to use only 1000 utterances for dev "
"(speeds up training).",
)
def validate_args(self):
if self.args.subset in ["L", "XL"]:
assert (
not self.args.shuffle
), "For GigaSpeech L/XL, you must use --shuffle 0 to avoid eagerly reading pyarrow manifests."
assert (
not self.args.check_cuts
), "For GigaSpeech L/XL, you must use --check-cuts 0 to avoid eagerly reading pyarrow manifests."
assert (
not self.args.bucketing_sampler
), "For GigaSpeech L/XL, you must use --bucketing-sampler 0 to avoid eagerly reading pyarrow manifests."
assert (
self.args.on_the_fly_feats
), "For GigaSpeech L/XL, you must use --on-the-fly-feats 1 as we do not pre-compute them by default."
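# A minimal sketch (not part of this commit) of the flag combination the checks
# above accept for the XL subset, i.e.
#   --subset XL --shuffle 0 --check-cuts 0 --bucketing-sampler 0 --on-the-fly-feats 1
# The Namespace below only fills in the fields that validate_args() inspects;
# a real args object also carries --feature-dir, --max-duration, etc.
#
#   from argparse import Namespace
#   xl_args = Namespace(
#       subset="XL",
#       shuffle=False,
#       check_cuts=False,
#       bucketing_sampler=False,
#       on_the_fly_feats=True,
#   )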
def train_dataloaders(self) -> DataLoader:
self.validate_args()
logging.info("About to get train cuts")
cuts_train = self.train_cuts()
self.total_train_cuts = len(cuts_train)
self.consumed_cuts = 0
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.feature_dir / 'cuts_musan.json.gz')
logging.info("About to create train dataset")
transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=True,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from the data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased the epoch size by 3x, we would apply prob 2/3
# and use 3x more epochs.
# Speed perturbation probably should come before concatenation,
# but in principle the transforms order doesn't have to be strict
# (e.g. it could be randomized).
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms  # noqa
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)), num_workers=20),
return_cuts=True,
)
if self.args.bucketing_sampler:
logging.info("Using BucketingSampler.")
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
train_dl = LhotseDataLoader(
train,
sampler=train_sampler,
num_workers=3,
prefetch_factor=5,
)
return train_dl
def valid_dataloaders(self) -> DataLoader:
self.validate_args()
logging.info("About to get dev cuts")
cuts_valid = self.valid_cuts()
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)), num_workers=8),
return_cuts=True,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=True,
)
valid_sampler = SingleCutSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = LhotseDataLoader(
validate,
sampler=valid_sampler,
num_workers=2,
)
return valid_dl
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
self.validate_args()
cuts = self.test_cuts()
is_list = isinstance(cuts, list)
test_loaders = []
@ -277,23 +315,19 @@ class GigaSpeechAsrDataModule(DataModule):
for cuts_test in cuts:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)), num_workers=8)
if self.args.on_the_fly_feats
else PrecomputedFeatures()
),
return_cuts=True,
)
sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration)
logging.debug("About to create test dataloader")
test_dl = LhotseDataLoader(test, sampler=sampler, num_workers=2)
test_loaders.append(test_dl)
if is_list:
@ -304,25 +338,48 @@ class GigaSpeechAsrDataModule(DataModule):
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
# Note: for the L and XL subsets, we expect the training manifest
# to be stored using pyarrow and pre-shuffled.
cuts_path_ext = "jsonl.gz" if self.args.subset not in ["L", "XL"] else "arrow"
cuts_train = CutSet.from_file(
self.args.feature_dir
/ f"gigaspeech_cuts_{self.args.subset}{get_context_suffix(self.args)}.{cuts_path_ext}"
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
if self.args.use_context_for_test:
path = self.args.feature_dir / f"gigaspeech_cuts_DEV{get_context_suffix(self.args)}.jsonl.gz"
else:
path = self.args.feature_dir / "gigaspeech_cuts_DEV.jsonl.gz"
logging.info(f"About to get valid cuts from {path}")
cuts_valid = load_manifest(path)
if self.args.small_dev:
return cuts_valid.subset(first=1000)
else:
return cuts_valid
@lru_cache()
def test_cuts(self) -> CutSet:
if self.args.use_context_for_test:
path = self.args.feature_dir / f"gigaspeech_cuts_TEST{get_context_suffix(self.args)}.jsonl.gz"
else:
path = self.args.feature_dir / "gigaspeech_cuts_TEST.jsonl.gz"
logging.info(f"About to get test cuts from {path}")
cuts_test = load_manifest(path)
return cuts_test
def inexhaustible_train_dataloaders(self):
return self
def __iter__(self):
# Workhorse for inexhaustible_train_dataloaders().
while True:
# self.total_train_cuts and self.consumed_cuts are maintained by this
# class (see __init__ and train_dataloaders()).
if self.total_train_cuts == 0 or self.consumed_cuts == self.total_train_cuts:
self.train_dl = self.train_dataloaders()
self.consumed_cuts = 0
for batch in self.train_dl:
self.consumed_cuts += len(batch["supervisions"]["text"])
yield batch
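
For context, a minimal sketch (not part of this commit) of how a training script might consume the never-ending batch stream; the step budget and the training step body are placeholders:

parser = argparse.ArgumentParser()
GigaSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()

datamodule = GigaSpeechAsrDataModule(args)
num_steps = 100000  # hypothetical: train for a fixed number of steps, not epochs
for step, batch in enumerate(datamodule.inexhaustible_train_dataloaders()):
    if step >= num_steps:
        break
    # batch["inputs"] and batch["supervisions"] come from K2SpeechRecognitionDataset;
    # run the forward/backward pass and optimizer step here.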

View File

@ -1,30 +0,0 @@
dl_dir='/home/storage07/zhangjunbo/data/'
output_dir=/ceph-hw/ly/data/gigaspeech_nb/
mkdir -p $output_dir/manifests
stage=2
stop_stage=2
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
echo "Implement and verify gigaspeech downloading later"
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# subset could be: ["XS", "S", "M", "L", "XL", "DEV", "TEST"]
# Currently only XS, DEV and TEST are verified.
# The others should also work.
subsets="XS DEV TEST"
for subset in $subsets; do
lhotse prepare gigaspeech \
-j 60 \
--subset=$subset \
$dl_dir/GigaSpeech $output_dir/manifests
done
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
echo "Stage 2: Compute fbank for gigaspeech"
mkdir -p $output_dir/fbank
./local/compute_fbank_gigaspeech.py
fi

View File

@ -0,0 +1,293 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Johns Hopkins University (Piotr Żelasko)
# Apache 2.0
import argparse
import os
import re
import subprocess
import sys
from contextlib import contextmanager
from pathlib import Path
import torch
from gigaspeech_datamodule import get_context_suffix
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomHdf5Writer,
SupervisionSegment,
combine,
)
from lhotse.recipes import prepare_gigaspeech, prepare_musan
from lhotse.utils import is_module_available
from icefall.utils import str2bool
# Torch's multithreaded behavior needs to be disabled, or it wastes a lot of CPU and
# slows things down. Do this outside of main() so that it takes effect
# even when we are not invoking the main function (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
@contextmanager
def get_executor():
# We'll either return a process pool or a distributed worker pool.
# Note that this has to be a context manager because we might use multiple
# context manager ("with" clauses) inside, and this way everything will
# free up the resources at the right time.
try:
# If this is executed on the CLSP grid, we will try to use the
# Grid Engine to distribute the tasks.
# Other clusters can also benefit from that, provided a cluster-specific wrapper.
# (see https://github.com/pzelasko/plz for reference)
#
# The following must be installed:
# $ pip install dask distributed
# $ pip install git+https://github.com/pzelasko/plz
name = subprocess.check_output("hostname -f", shell=True, text=True)
if name.strip().endswith(".clsp.jhu.edu"):
import plz
from distributed import Client
with plz.setup_cluster() as cluster:
cluster.scale(80)
yield Client(cluster)
return
except Exception:
pass
# No need to return anything - compute_and_store_features
# will just instantiate the pool itself.
yield None
def locate_corpus(*corpus_dirs):
for d in corpus_dirs:
if os.path.exists(d):
return d
print(
"Please create a place on your system to put the downloaded GigaSpeech data "
"and add it to `corpus_dirs`."
)
sys.exit(1)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--num-jobs",
type=int,
default=min(5, os.cpu_count()),
help="Number of parallel jobs.",
)
parser.add_argument(
"--subset",
type=str,
default="XS",
help="Select the GigaSpeech subset (XS|S|M|L|XL)",
)
parser.add_argument(
"--context-window",
type=float,
default=0.0,
help="Training cut duration in seconds. "
"Use 0 to train on supervision segments without acoustic context, with variable cut lengths; "
"a value larger than zero will create multi-supervision cuts with actual acoustic context.",
)
parser.add_argument(
"--context-direction",
type=str,
default="center",
help="If context-window is 0, does nothing. "
"If it's larger than 0, determines in which direction (relative to the supervision) "
"to seek for extra acoustic context. Available values: (left|right|center|random).",
)
parser.add_argument(
"--precomputed-features",
type=str2bool,
default=True,
help="Should we pre-compute features and store them on disk or not. "
"It is recommended to disable it for L and XL splits as the pre-computation "
"might currently consume excessive memory and time -- use on-the-fly feature "
"extraction in the training script instead.",
)
return parser
# Similar text filtering and normalization procedure as in:
# https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh
def normalize_text(
utt: str,
punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"),
whitespace_pattern=re.compile(r"\s\s+"),
) -> str:
return whitespace_pattern.sub(" ", punct_pattern.sub("", utt))
def has_no_oov(
sup: SupervisionSegment, oov_pattern=re.compile(r"<(SIL|MUSIC|NOISE|OTHER)>")
) -> bool:
return oov_pattern.search(sup.text) is None
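# Illustration only (the transcript strings below are invented): what these
# helpers do to GigaSpeech-style supervisions before feature extraction.
#
#   sup = SupervisionSegment(
#       id="utt1", recording_id="rec1", start=0.0, duration=2.5,
#       text="HELLO <COMMA> WORLD <PERIOD>",
#   )
#   has_no_oov(sup)            # True: no <SIL|MUSIC|NOISE|OTHER> tag present
#   normalize_text(sup.text)   # "HELLO WORLD " (tags dropped, repeated whitespace
#                              #  collapsed; a trailing space may survive)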
def main():
args = get_parser().parse_args()
dataset_parts = [args.subset, "DEV", "TEST"]
if args.subset in ["L", "XL"]:
assert is_module_available("pyarrow"), (
"Running the GigaSpeech recipe for L and XL splits "
"currently requires installing optional dependencies: "
"'pip install pyarrow pandas'."
)
print("Parts we will prepare: ", dataset_parts)
corpus_dir = locate_corpus(
Path("/export/corpora5/gigaspeech"),
Path("/exp/pzelasko/gigaspeech"),
Path("/home/storage07/zhangjunbo/data/GigaSpeech")
)
musan_dir = locate_corpus(
Path("/export/corpora5/JHU/musan"),
Path("/export/common/data/corpora/MUSAN/musan"),
Path("/root/fangjun/data/musan"),
)
output_dir = Path("exp/data")
print("GigaSpeech manifest preparation:")
gigaspeech_manifests = prepare_gigaspeech(
corpus_dir=corpus_dir,
dataset_parts=dataset_parts,
output_dir=output_dir,
num_jobs=args.num_jobs,
)
print("Musan manifest preparation:")
musan_cuts_path = output_dir / "cuts_musan.json.gz"
musan_manifests = prepare_musan(
corpus_dir=musan_dir, output_dir=output_dir, parts=("music", "speech", "noise")
)
ctx_suffix = get_context_suffix(args)
print("Feature extraction:")
extractor = Fbank(FbankConfig(num_mel_bins=80))
with get_executor() as ex: # Initialize the executor only once.
for partition, manifests in gigaspeech_manifests.items():
# For the L and XL partitions we are going to store the manifest using pyarrow.
cuts_path_ext = "jsonl.gz" if partition not in ["L", "XL"] else "arrow"
raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
cuts_path = (
output_dir / f"gigaspeech_cuts_{partition}{ctx_suffix}.{cuts_path_ext}"
)
if raw_cuts_path.is_file():
print(f"{partition} already exists - skipping transcript checking.")
else:
# Note this step makes the recipe different from LibriSpeech:
# We must filter out some utterances and remove punctuation to be consistent with Kaldi.
print("Filtering OOV utterances from supervisions")
manifests["supervisions"] = manifests["supervisions"].filter(has_no_oov)
print("Normalizing text in", partition)
for sup in manifests["supervisions"]:
sup.text = normalize_text(sup.text)
# Create long-recording cut manifests.
print("Processing", partition)
cut_set = CutSet.from_manifests(
recordings=manifests["recordings"],
supervisions=manifests["supervisions"],
)
# Run data augmentation that needs to be done in the time domain.
if partition not in ["DEV", "TEST"]:
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
cut_set.to_file(raw_cuts_path)
if cuts_path.is_file():
print(
f"{partition} already exists - skipping cutting into sub-segments and feature extraction."
)
else:
try:
# If we skipped initializing `cut_set` because it exists on disk, we'll load it.
# This helps us avoid re-computing the features for different variants of
# context windows.
cut_set
except NameError:
print(f"Reading {partition} raw cuts from disk.")
cut_set = CutSet.from_file(raw_cuts_path)
# Note this step makes the recipe different from LibriSpeech:
# Since recordings are long, the initial CutSet has very long cuts with plenty of supervisions.
# We cut these into smaller chunks centered around each supervision, possibly adding acoustic
# context.
print(f"About to split {partition} raw cuts into smaller chunks.")
cut_set = cut_set.trim_to_supervisions(
keep_overlapping=False,
min_duration=None
if args.context_window <= 0.0
else args.context_window,
context_direction=args.context_direction,
)
if partition in ["L", "XL"]:
# Before storing manifests in the arrow format, we want to pre-shuffle them,
# as the sampler won't be able to do it later in an efficient manner.
cut_set = cut_set.shuffle()
if args.precomputed_features:
# Extract the features after cutting large recordings into smaller cuts.
# Note: we support very efficient "chunked" feature reads with the argument
# `storage_type=ChunkedLilcomHdf5Writer`, but we don't support efficient
# data augmentation and feature computation for long recordings yet.
# Therefore, we sacrifice some storage for the ability to precompute
# features on shorter chunks, without memory blow-ups.
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/feats_gigaspeech_{partition}",
# when an executor is specified, make more partitions
num_jobs=args.num_jobs if ex is None else 80,
executor=ex,
)
cut_set.to_file(cuts_path)
# Remove cut_set so the next iteration can correctly infer whether it needs to
# load the raw cuts from disk or not.
del cut_set
# Now onto Musan
if not musan_cuts_path.is_file():
print("Extracting features for Musan")
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (
CutSet.from_manifests(
recordings=combine(
part["recordings"] for part in musan_manifests.values()
)
)
.cut_into_windows(10.0)
.filter(lambda c: c.duration > 5)
.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/feats_musan",
num_jobs=args.num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomHdf5Writer,
)
)
musan_cuts.to_file(musan_cuts_path)
if __name__ == "__main__":
main()
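
As a quick sanity check of the outputs, a sketch under the default arguments (not part of the script): the cut manifests written above can be read back with the same lhotse API the data module uses.

from lhotse import CutSet

# Assumes the default output_dir ("exp/data"), the XS subset and
# --context-window 0.0, which yields an empty context suffix.
cuts = CutSet.from_file("exp/data/gigaspeech_cuts_XS.jsonl.gz")
print(next(iter(cuts)))  # inspect one cut: id, duration, supervisions, features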