latest snowfall gigaspeech script

Guo Liyong 2021-11-04 16:49:56 +08:00
parent 343f99305f
commit 83b2705b44
2 changed files with 182 additions and 128 deletions

View File

@@ -2,13 +2,14 @@
 # Apache 2.0
 import argparse
 import logging
+import warnings
 from functools import lru_cache
 from pathlib import Path
 from typing import List, Union
 from torch.utils.data import DataLoader
-from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse import CutSet, KaldifeatFbank, FbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,
     CutConcatenate,
@@ -24,11 +25,17 @@ from icefall.utils import str2bool
 from icefall.dataset.datamodule import DataModule

-def get_context_suffix(args):
-    if args.giga_context_window is None or args.giga_context_window <= 0.0:
-        ctx_suffix = ""
-    else:
-        ctx_suffix = f"_{args.giga_context_direction}{args.giga_context_window}"
+def get_context_suffix(args, subparser=True):
+    if subparser:
+        if args.giga_context_window is None or args.giga_context_window <= 0.0:
+            ctx_suffix = ""
+        else:
+            ctx_suffix = f"_{args.giga_context_direction}{args.giga_context_window}"
+    else:
+        if args.context_window is None or args.context_window <= 0.0:
+            ctx_suffix = ""
+        else:
+            ctx_suffix = f"_{args.context_direction}{args.context_window}"
     return ctx_suffix
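For reference, a small usage sketch of the updated helper (the Namespace values below are made up for illustration, and get_context_suffix is assumed to be in scope as defined above):

from argparse import Namespace

# giga_*-prefixed attributes correspond to the training-script subparser (subparser=True),
# plain attributes to the standalone prepare script (subparser=False).
print(get_context_suffix(Namespace(giga_context_window=0.5, giga_context_direction="center")))
# -> "_center0.5"
print(get_context_suffix(Namespace(context_window=None, context_direction="center"), subparser=False))
# -> "" (no extra acoustic context requested)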
@@ -36,13 +43,14 @@ class GigaSpeechAsrDataModule(DataModule):
     """
     DataModule for K2 ASR experiments.
     It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean and test-other).
     It contains all the common data pipeline modules used in ASR experiments, e.g.:
     - dynamic batch size,
     - bucketing samplers,
     - cut concatenation,
     - augmentation,
     - on-the-fly feature extraction
     This class should be derived for specific corpora used in ASR tasks.
     """
@@ -57,82 +65,102 @@ class GigaSpeechAsrDataModule(DataModule):
         parser = subparsers.add_parser(name='giga')
         super().add_arguments(parser)
         group = parser.add_argument_group(
-            title='ASR data related options',
-            description='These options are used for the preparation of PyTorch DataLoaders '
-                        'from Lhotse CutSet\'s -- they control the effective batch sizes, '
-                        'sampling strategies, applied data augmentations, etc.'
+            title="ASR data related options",
+            description="These options are used for the preparation of PyTorch DataLoaders "
+            "from Lhotse CutSet's -- they control the effective batch sizes, "
+            "sampling strategies, applied data augmentations, etc.",
         )
         group.add_argument(
-            '--feature-dir',
+            "--feature-dir",
             dest="giga_feature_dir",
             type=Path,
             default=Path('exp/giga_data'),
-            help='Path to directory with train/valid/test cuts.'
+            help="Path to directory with train/valid/test cuts.",
         )
         group.add_argument(
-            '--max-duration',
+            "--max-duration",
             dest="giga_max_duration",
             type=int,
             default=500.0,
-            help="Maximum pooled recordings duration (seconds) in a single batch.")
+            help="Maximum pooled recordings duration (seconds) in a single batch.",
+        )
         group.add_argument(
-            '--bucketing-sampler',
+            "--bucketing-sampler",
             dest="giga_bucketing_sampler",
             type=str2bool,
             default=False,
-            help='When enabled, the batches will come from buckets of '
-                 'similar duration (saves padding frames).')
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
         group.add_argument(
-            '--num-buckets',
+            "--num-buckets",
+            dest="giga_num_buckets",
             type=int,
             default=30,
-            dest="giga_num_buckets",
-            help='The number of buckets for the BucketingSampler'
-                 '(you might want to increase it for larger datasets).')
+            help="The number of buckets for the BucketingSampler"
+            "(you might want to increase it for larger datasets).",
+        )
         group.add_argument(
-            '--concatenate-cuts',
+            "--concatenate-cuts",
             dest="giga_concatenate_cuts",
             type=str2bool,
             default=True,
-            help='When enabled, utterances (cuts) will be concatenated '
-                 'to minimize the amount of padding.')
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
+        )
         group.add_argument(
-            '--duration-factor',
+            "--duration-factor",
             dest="giga_duration_factor",
             type=float,
             default=1.0,
-            help='Determines the maximum duration of a concatenated cut '
-                 'relative to the duration of the longest cut in a batch.')
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
+        )
         group.add_argument(
-            '--gap',
+            "--gap",
             dest="giga_gap",
             type=float,
             default=1.0,
-            help='The amount of padding (in seconds) inserted between concatenated cuts. '
-                 'This padding is filled with noise when noise augmentation is used.')
+            help="The amount of padding (in seconds) inserted between concatenated cuts. "
+            "This padding is filled with noise when noise augmentation is used.",
+        )
         group.add_argument(
-            '--on-the-fly-feats',
+            "--on-the-fly-feats",
             dest="giga_on_the_fly_feats",
             type=str2bool,
             default=False,
-            help='When enabled, use on-the-fly cut mixing and feature extraction. '
-                 'Will drop existing precomputed feature manifests if available.'
+            help="When enabled, use on-the-fly cut mixing and feature extraction. "
+            "Will drop existing precomputed feature manifests if available.",
         )
         group.add_argument(
-            '--shuffle',
+            "--shuffle",
             dest="giga_shuffle",
             type=str2bool,
             default=True,
-            help='When enabled (=default), the examples will be shuffled for each epoch.'
+            help="When enabled (=default), the examples will be shuffled for each epoch.",
         )
         group.add_argument(
-            '--check-cuts',
-            dest="giga_check_cuts",
+            "--return-cuts",
+            dest="giga_return_cuts",
             type=str2bool,
             default=True,
-            help='When enabled (=default), we will iterate over the whole training cut set '
-                 'to validate it. It should be disabled when using Apache Arrow manifests '
-                 'to avoid an excessive starting time of the script with datasets>1000h.'
+            help="When enabled, each batch will have the field: batch['supervisions']['cut']"
+            " with the cuts that were used to construct it.",
+        )
+        group.add_argument(
+            "--num-workers",
+            dest="giga_num_workers",
+            type=int,
+            default=4,
+            help="The number of training dataloader workers that collect the batches.",
+        )
+        group.add_argument(
+            "--num-workers-inner",
+            dest="giga_num_workers_inner",
+            type=int,
+            default=16,
+            help="The number of sub-workers (replicated for each of training dataloader"
+            " workers) that parallelize the I/O to collect each batch.",
         )
         # GigaSpeech specific arguments
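The dest="giga_*" pattern above namespaces every option of the 'giga' subparser, presumably so its values do not clash with other datamodules sharing the same parser. A minimal argparse sketch of the idea (option names and values are illustrative):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="ASR data related options")
# The option is exposed as --max-duration but stored under a prefixed attribute.
group.add_argument("--max-duration", dest="giga_max_duration", type=int, default=500.0)

args = parser.parse_args(["--max-duration", "300"])
print(args.giga_max_duration)  # -> 300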
@@ -162,35 +190,36 @@ class GigaSpeechAsrDataModule(DataModule):
             "to seek for extra acoustic context. Available values: (left|right|center|random).",
         )
         group.add_argument(
-            '--use-context-for-test',
+            "--use-context-for-test",
             dest="giga_use_context_for_text",
             type=str2bool,
             default=False,
-            help='Should we read cuts with acoustic context or without it. '
-                 '(note: for now, they may contain duplicated segments)'
+            help="Should we read cuts with acoustic context or without it. "
+            "(note: for now, they may contain duplicated segments)",
         )
         group.add_argument(
-            '--small-dev',
+            "--small-dev",
             dest="giga_small_dev",
             type=str2bool,
             default=False,
-            help='Should we use only 1000 utterances for dev (speeds up training)'
+            help="Should we use only 1000 utterances for dev (speeds up training)",
         )

     def validate_args(self):
-        if self.args.giga_subset in ['L', 'XL']:
+        if self.args.giga_subset in ["L", "XL"]:
             assert (
                 self.args.giga_shuffle == False
             ), "For GigaSpeech L/XL, you must use --shuffle 0 to avoid eagerly reading pyarrow manifests."
-            assert (
-                self.args.giga_check_cuts == False
-            ), "For GigaSpeech L/XL, you must use --check-cuts 0 to avoid eagerly reading pyarrow manifests."
             assert (
                 self.args.giga_bucketing_sampler == False
             ), "For GigaSpeech L/XL, you must use --bucketing-sampler 0 to avoid eagerly reading pyarrow manifests."
-            assert (
-                self.args.giga_on_the_fly_feats == True
-            ), "For GigaSpeech L/XL, you must use --on-the-fly-feats 1 as we do not pre-compute them by default."
+            # compute_and_store_features_batch is efficient for L/XL subsets.
+            # if not self.args.giga_on_the_fly_feats:
+            #     warnings.warn(
+            #         "For GigaSpeech L/XL, we advise to set --on-the-fly-feats 1,"
+            #         " as we do not pre-compute them by default. If you pre-computed them,"
+            #         " ignore this warning."
+            #     )

     def train_dataloaders(self) -> DataLoader:
         self.validate_args()
@@ -200,27 +229,26 @@ class GigaSpeechAsrDataModule(DataModule):
         self.consumed_cuts = 0
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(self.args.giga_feature_dir / 'cuts_musan.json.gz')
+        cuts_musan = load_manifest(self.args.giga_feature_dir / "cuts_musan.json.gz")
         logging.info("About to create train dataset")
         transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
         if self.args.giga_concatenate_cuts:
-            logging.info(f'Using cut concatenation with duration factor '
-                         f'{self.args.giga_duration_factor} and gap {self.args.giga_gap}.')
+            logging.info(
+                f"Using cut concatenation with duration factor "
+                f"{self.args.giga_duration_factor} and gap {self.args.giga_gap}."
+            )
             # Cut concatenation should be the first transform in the list,
             # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
             transforms = [
                 CutConcatenate(
-                    duration_factor=self.args.giga_duration_factor,
-                    gap=self.args.giga_gap
+                    duration_factor=self.args.giga_duration_factor, gap=self.args.giga_gap
                 )
             ] + transforms
         train = K2SpeechRecognitionDataset(
-            # cuts_train,
             cut_transforms=transforms,
-            return_cuts=True,
-            # check_inputs=self.args.giga_check_cuts,
+            return_cuts=self.args.giga_return_cuts,
         )
         if self.args.giga_on_the_fly_feats:
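The comment in the hunk above explains why concatenation has to run before noise mixing; a condensed sketch of that ordering (cuts_musan is assumed to be the MUSAN CutSet loaded earlier in the method):

from lhotse.dataset import CutConcatenate, CutMix

# CutMix is created first, but CutConcatenate is prepended so it runs first;
# the noise mixing then fills the gaps between concatenated utterances
# instead of leaving padded silence.
transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
transforms = [CutConcatenate(duration_factor=1.0, gap=1.0)] + transforms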
@@ -231,23 +259,24 @@ class GigaSpeechAsrDataModule(DataModule):
             # # but in principle the transforms order doesn't have to be strict (e.g. could be randomized)
             # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)] + transforms
             train = K2SpeechRecognitionDataset(
-                cuts=cuts_train,
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)), num_workers=20),
-                return_cuts=True,
-                # check_inputs=self.args.giga_check_cuts,
+                input_strategy=OnTheFlyFeatures(
+                    KaldifeatFbank(FbankConfig(num_mel_bins=80)),
+                    num_workers=self.args.giga_num_workers_inner,
+                ),
+                return_cuts=self.args.giga_return_cuts,
             )
         if self.args.giga_bucketing_sampler:
-            logging.info('Using BucketingSampler.')
+            logging.info("Using BucketingSampler.")
             train_sampler = BucketingSampler(
                 cuts_train,
                 max_duration=self.args.giga_max_duration,
                 shuffle=self.args.giga_shuffle,
-                num_buckets=self.args.giga_num_buckets
+                num_buckets=self.args.giga_num_buckets,
             )
         else:
-            logging.info('Using SingleCutSampler.')
+            logging.info("Using SingleCutSampler.")
             train_sampler = SingleCutSampler(
                 cuts_train,
                 max_duration=self.args.giga_max_duration,
@@ -264,14 +293,11 @@ class GigaSpeechAsrDataModule(DataModule):
         train_dl = LhotseDataLoader(
             train,
             sampler=train_sampler,
-            num_workers=3,
+            num_workers=self.args.giga_num_workers,
             prefetch_factor=5,
         )
         return train_dl

-    def inexhaustible_train_dataloaders(self):
-        return self
     def valid_dataloaders(self) -> DataLoader:
         self.validate_args()
         logging.info("About to get dev cuts")
@@ -279,27 +305,25 @@ class GigaSpeechAsrDataModule(DataModule):
         transforms = []
         if self.args.giga_concatenate_cuts:
-            transforms = [ CutConcatenate(
-                duration_factor=self.args.giga_duration_factor,
-                gap=self.args.giga_gap)
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.giga_duration_factor, gap=self.args.giga_gap
+                )
             ] + transforms
         logging.info("About to create dev dataset")
         if self.args.giga_on_the_fly_feats:
             validate = K2SpeechRecognitionDataset(
-                cuts_valid,
                 cut_transforms=transforms,
-                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)), num_workers=8),
-                return_cuts=True,
-                check_inputs=self.args.giga_check_cuts,
+                input_strategy=OnTheFlyFeatures(
+                    KaldifeatFbank(FbankConfig(num_mel_bins=80)), num_workers=8
+                ),
+                return_cuts=self.args.giga_return_cuts,
             )
         else:
             validate = K2SpeechRecognitionDataset(
-                # cuts_valid,
                 cut_transforms=transforms,
-                return_cuts=True,
-                # check_inputs=self.args.giga_check_cuts,
+                return_cuts=self.args.giga_return_cuts,
             )
         valid_sampler = SingleCutSampler(
             cuts_valid,
@@ -332,14 +356,12 @@ class GigaSpeechAsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                cuts_test,
                 input_strategy=(
-                    OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)), num_workers=8)
+                    OnTheFlyFeatures(KaldifeatFbank(FbankConfig(num_mel_bins=80)), num_workers=8)
                     if self.args.giga_on_the_fly_feats
                     else PrecomputedFeatures()
                 ),
-                return_cuts=True,
-                check_inputs=self.args.giga_check_cuts,
+                return_cuts=self.args.giga_return_cuts,
             )
             sampler = SingleCutSampler(cuts_test, max_duration=self.args.giga_max_duration)
             logging.debug("About to create test dataloader")
@@ -355,18 +377,30 @@ class GigaSpeechAsrDataModule(DataModule):
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        # Note: for L and XL subsets, we are expecting that the training manifest is stored using pyarrow and pre-shuffled.
-        cuts_path_ext = 'jsonl.gz' if self.args.giga_subset not in ['L', 'XL'] else 'arrow'
-        cuts_train = CutSet.from_file(
+        path = (
             self.args.giga_feature_dir
-            / f"gigaspeech_cuts_{self.args.giga_subset}{get_context_suffix(self.args)}.{cuts_path_ext}"
+            / f"gigaspeech_cuts_{self.args.giga_subset}{get_context_suffix(self.args)}.jsonl.gz"
         )
+        if self.args.giga_subset in ["L", "XL"]:
+            # "L" and "XL" partitions are large enough that we have to read their manifests lazily;
+            # The "CutSet" holds a file handle and reads the items sequentially on-the-fly to avoid
+            # wasting memory and time pre-reading everything. Some operations on "CutSet" won't work,
+            # e.g. shuffling (or they would have read everything into memory in the process).
+            # We expect that the manifests read lazily are pre-shuffled, otherwise you might experience
+            # issues with convergence.
+            cuts_train = CutSet.from_jsonl_lazy(path)
+        else:
+            # For other subsets, just read everything into memory.
+            cuts_train = CutSet.from_file(path)
         return cuts_train

     @lru_cache()
     def valid_cuts(self) -> CutSet:
         if self.args.giga_use_context_for_test:
-            path = self.args.giga_feature_dir / f"gigaspeech_cuts_DEV{get_context_suffix(self.args)}.jsonl.gz"
+            path = (
+                self.args.giga_feature_dir
+                / f"gigaspeech_cuts_DEV{get_context_suffix(self.args)}.jsonl.gz"
+            )
         else:
             path = self.args.giga_feature_dir / f"gigaspeech_cuts_DEV.jsonl.gz"
         logging.info(f"About to get valid cuts from {path}")
@@ -379,7 +413,10 @@ class GigaSpeechAsrDataModule(DataModule):
     @lru_cache()
     def test_cuts(self) -> CutSet:
         if self.args.giga_use_context_for_test:
-            path = self.args.giga_feature_dir / f"gigaspeech_cuts_TEST{get_context_suffix(self.args)}.jsonl.gz"
+            path = (
+                self.args.giga_feature_dir
+                / f"gigaspeech_cuts_TEST{get_context_suffix(self.args)}.jsonl.gz"
+            )
         else:
             path = self.args.giga_feature_dir / f"gigaspeech_cuts_TEST.jsonl.gz"
         logging.info(f"About to get test cuts from {path}")

View File

@@ -22,7 +22,6 @@ from lhotse import (
     combine,
 )
 from lhotse.recipes import prepare_gigaspeech, prepare_musan
-from lhotse.utils import is_module_available
 from icefall.utils import str2bool

 # Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
@@ -81,7 +80,7 @@ def get_parser():
     parser.add_argument(
         "--num-jobs",
         type=int,
-        default=min(5, os.cpu_count()),
+        default=min(15, os.cpu_count()),
         help="Number of parallel jobs.",
     )
     parser.add_argument(
@@ -115,6 +114,19 @@ def get_parser():
         "might currently consume excessive memory and time -- use on-the-fly feature "
         "extraction in the training script instead.",
     )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of workers for compute_and_store_features_batch.",
+    )
+    parser.add_argument(
+        "--batch-duration",
+        type=float,
+        default=600.0,
+        help="The maximum number of audio seconds in a batch"
+        "for compute_and_store_features_batch.",
+    )
     return parser
@@ -139,12 +151,6 @@ def has_no_oov(
 def main():
     args = get_parser().parse_args()
     dataset_parts = [args.subset, "DEV", "TEST"]
-    if args.subset in ["L", "XL"]:
-        assert is_module_available("pyarrow"), (
-            "Running the GigaSpeech recipe for L and XL splits "
-            "currently requires installing optional dependencies: "
-            "'pip install pyarrow pandas'."
-        )
     print("Parts we will prepare: ", dataset_parts)
@@ -159,7 +165,7 @@ def main():
         Path("/root/fangjun/data/musan"),
     )
-    output_dir = Path("exp/data")
+    output_dir = Path("exp/giga_data")
     print("GigaSpeech manifest preparation:")
     gigaspeech_manifests = prepare_gigaspeech(
         corpus_dir=corpus_dir,
@@ -174,21 +180,19 @@ def main():
         corpus_dir=musan_dir, output_dir=output_dir, parts=("music", "speech", "noise")
     )
-    ctx_suffix = get_context_suffix(args)
+    ctx_suffix = get_context_suffix(args, subparser=False)
     print("Feature extraction:")
     extractor = Fbank(FbankConfig(num_mel_bins=80))
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, manifests in gigaspeech_manifests.items():
-            # For L and XL partition we are going to store the manifest using pyarrow.
-            cuts_path_ext = "jsonl.gz" if partition not in ["L", "XL"] else "arrow"
             raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz"
             cuts_path = (
-                output_dir / f"gigaspeech_cuts_{partition}{ctx_suffix}.{cuts_path_ext}"
+                output_dir / f"gigaspeech_cuts_{partition}{ctx_suffix}.jsonl.gz"
             )
             if raw_cuts_path.is_file():
-                print(f"{partition} already exists - skipping checking transcript.")
+                print(f"{partition} already exists - skipping feature extraction.")
             else:
                 # Note this step makes the recipe different than LibriSpeech:
                 # We must filter out some utterances and remove punctuation to be consistent with Kaldi.
@@ -217,7 +221,7 @@ def main():
             if cuts_path.is_file():
                 print(
-                    f"{partition} already exists - skipping cutting into sub-segments and feature extraction."
+                    f"{partition} already exists - skipping cutting into sub-segments."
                 )
             else:
                 try:
@@ -241,7 +245,7 @@ def main():
                     context_direction=args.context_direction,
                 )
                 if partition in ["L", "XL"]:
-                    # Before storing manifests in the arrow format, we want to pre-shuffle them,
+                    # Before storing manifests, we want to pre-shuffle them,
                     # as the sampler won't be able to do it later in an efficient manner.
                     cut_set = cut_set.shuffle()
@@ -252,14 +256,21 @@ def main():
                 # data augmentation and feature computation for long recordings yet.
                 # Therefore, we sacrifice some storage for the ability to precompute
                 # features on shorter chunks, without memory blow-ups.
-                cut_set = cut_set.compute_and_store_features(
+                # cut_set = cut_set.compute_and_store_features(
+                #     extractor=extractor,
+                #     storage_path=f"{output_dir}/feats_gigaspeech_{partition}",
+                #     # when an executor is specified, make more partitions
+                #     num_jobs=args.num_jobs if ex is None else 80,
+                #     executor=ex,
+                # )
+                cut_set = cut_set.compute_and_store_features_batch(
                     extractor=extractor,
                     storage_path=f"{output_dir}/feats_gigaspeech_{partition}",
-                    # when an executor is specified, make more partitions
-                    num_jobs=args.num_jobs if ex is None else 80,
-                    executor=ex,
+                    batch_duration=args.batch_duration,
+                    num_workers=args.num_workers,
                 )
             cut_set.to_file(cuts_path)
             # Remove cut_set so the next iteration can correctly infer whether it needs to
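Taken in isolation, the new batched extraction looks roughly like this (paths are placeholders; the keyword arguments mirror the call above):

from lhotse import CutSet, Fbank, FbankConfig

cuts = CutSet.from_file("exp/giga_data/gigaspeech_cuts_XS_raw.jsonl.gz")  # placeholder manifest
extractor = Fbank(FbankConfig(num_mel_bins=80))

# Cuts are grouped into batches of up to --batch-duration seconds of audio and
# processed by --num-workers dataloader workers, replacing the per-cut,
# executor-based compute_and_store_features() call that is now commented out.
cuts = cuts.compute_and_store_features_batch(
    extractor=extractor,
    storage_path="exp/giga_data/feats_demo",  # placeholder storage path
    batch_duration=600.0,
    num_workers=4,
)
cuts.to_file("exp/giga_data/gigaspeech_cuts_demo.jsonl.gz")  # placeholder output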
@@ -278,13 +289,19 @@ def main():
             )
             .cut_into_windows(10.0)
             .filter(lambda c: c.duration > 5)
-            .compute_and_store_features(
+            .compute_and_store_features_batch(
                 extractor=extractor,
                 storage_path=f"{output_dir}/feats_musan",
-                num_jobs=args.num_jobs if ex is None else 80,
-                executor=ex,
-                storage_type=LilcomHdf5Writer,
+                batch_duration=args.batch_duration,
+                num_workers=args.num_workers,
             )
+            # .compute_and_store_features(
+            #     extractor=extractor,
+            #     storage_path=f"{output_dir}/feats_musan",
+            #     num_jobs=args.num_jobs if ex is None else 80,
+            #     executor=ex,
+            #     storage_type=LilcomHdf5Writer,
+            # )
         )
         musan_cuts.to_file(musan_cuts_path)