diff --git a/egs/librispeech/ASR/conformer_ctc/gigaspeech_datamodule.py b/egs/librispeech/ASR/conformer_ctc/gigaspeech_datamodule.py new file mode 100644 index 000000000..0ba184d8c --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/gigaspeech_datamodule.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import List, Union + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( + BucketingSampler, + CutConcatenate, + CutMix, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool + + +class GigaSpeechAsrDataModule(DataModule): + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + super().add_arguments(parser) + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-giga", + type=str2bool, + default=False, + help="When enabled, use XL part of GigaSpeech. " + "Otherwise, use XS subset.", + ) + group.add_argument( + "--feature-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + def train_dataloaders(self) -> DataLoader: + logging.info("About to get train cuts") + cuts_train = self.train_cuts() + + logging.info("About to create train dataset") + transforms = None + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + # input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using BucketingSampler.") + train_sampler = BucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + bucket_method="equal_duration", + drop_last=True, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return train_dl + + def valid_dataloaders(self) -> DataLoader: + logging.info("About to get dev cuts") + cuts_valid = self.valid_cuts() + + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = BucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: + cuts = self.test_cuts() + is_list = isinstance(cuts, list) + test_loaders = [] + if not is_list: + cuts = [cuts] + + for cuts_test in cuts: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = BucketingSampler( + cuts_test, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + test_loaders.append(test_dl) + + if is_list: + return test_loaders + else: + return test_loaders[0] + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + # TODO(Liyong Guo): Support S, M, L if needed + if self.args.full_giga: + cuts_train = load_manifest( + self.args.feature_dir / "cuts_XL.json.gz" + ) + else: + cuts_train = load_manifest( + self.args.feature_dir / "cuts_XS.json.gz" + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest( + self.args.feature_dir / "cuts_DEV.json.gz" + ) + return cuts_valid + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + logging.info("About to get dev cuts") + cuts_test = load_manifest( + self.args.feature_dir / "cuts_TEST.json.gz" + ) + return cuts_test diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech.py new file mode 100755 index 000000000..0aab94969 --- /dev/null +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the GigaSpeech dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_gigaspeech(): + manifests_dir = Path("data/manifests") + output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + + dataset_parts = ( + "XS", + "S", + "M", + "L", + "XL", + "DEV", + "TEST", + ) + + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=manifests_dir, + prefix="gigaspeech", + suffix="jsonl.gz", + ) + assert manifests is not None + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + if (output_dir / f"cuts_{partition}.json.gz").is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomHdf5Writer, + ) + cut_set.to_json(output_dir / f"cuts_{partition}.json.gz") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_fbank_gigaspeech() diff --git a/egs/librispeech/ASR/prepare_giga.sh b/egs/librispeech/ASR/prepare_giga.sh new file mode 100755 index 000000000..ecdb2261e --- /dev/null +++ b/egs/librispeech/ASR/prepare_giga.sh @@ -0,0 +1,30 @@ + +dl_dir='/home/storage07/zhangjunbo/data/' +output_dir=/ceph-hw/ly/data/gigaspeech_nb/ + +mkdir -p $output_dir/manifests + +stage=2 +stop_stage=2 +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + echo "Implement and verify gigaspeech downloading later" +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + # subset could be: ["XS", "S", "M", "L", "XL", "DEV" "TEST"] + # Currently only XS DEV TEST are verified + # Others SHOULD also work + subsets="XS DEV TEST" + for subset in $subsets; do + lhotse prepare gigaspeech \ + -j 60 \ + --subset=$subset \ + $dl_dir/GigaSpeech $output_dir/manifests + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 3: Compute fbank for gigaspeech" + mkdir -p $output_dir/fbank + ./local/compute_fbank_gigaspeech.py +fi